├── Clipper.py ├── LICENSE ├── README.md ├── VehicleDC.py ├── bbox.py ├── car.cfg ├── car.data ├── car.names ├── checkpoints └── test ├── darknet.py ├── darknet_util.py ├── dataset.py ├── preprocess.py ├── test_imgs ├── test_0.jpg ├── test_1.jpg ├── test_10.jpg ├── test_11.jpg ├── test_12.jpg ├── test_13.jpg ├── test_14.jpg ├── test_15.jpg ├── test_16.jpg ├── test_17.jpg ├── test_18.jpg ├── test_19.jpg ├── test_2.jpg ├── test_20.jpg ├── test_3.jpg ├── test_4.jpg ├── test_5.jpg ├── test_6.jpg ├── test_7.jpg ├── test_8.jpg └── test_9.jpg ├── test_result ├── test_0.jpg ├── test_1.jpg ├── test_10.jpg ├── test_11.jpg ├── test_12.jpg ├── test_13.jpg ├── test_14.jpg ├── test_15.jpg ├── test_16.jpg ├── test_17.jpg ├── test_18.jpg ├── test_19.jpg ├── test_2.jpg ├── test_20.jpg ├── test_3.jpg ├── test_4.jpg ├── test_5.jpg ├── test_6.jpg ├── test_7.jpg ├── test_8.jpg └── test_9.jpg ├── train_vehicle_multilabel.py └── utils.py /Clipper.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import sys 5 | import re 6 | import time 7 | import pickle 8 | import shutil 9 | import random 10 | import argparse 11 | 12 | # from darknet_util import * 13 | # from darknet import Darknet 14 | # from preprocess import prep_image, process_img, inp_to_image 15 | 16 | # import torch 17 | # import torchvision 18 | # import paramiko 19 | 20 | import cv2  # needed by letterbox_image and the detection helpers below 21 | import numpy as np 22 | import PIL 23 | from PIL import Image 24 | from matplotlib import pyplot as plt 25 | from matplotlib.widgets import Cursor 26 | from matplotlib.image import AxesImage 27 | # from scipy.spatial.distance import cityblock 28 | # from tqdm import tqdm 29 | from matplotlib.patches import Rectangle  # used by Cropper.init_rect 30 | # for matplotlib to display Chinese characters correctly 31 | from pylab import * 32 | mpl.rcParams['font.sans-serif'] = ['SimHei'] 33 | 34 | # use_cuda = True # True 35 | # os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 36 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0' 37 | # device = torch.device( 38 | # 'cuda:0' if torch.cuda.is_available() and use_cuda else 'cpu') 39 | 40 | # if use_cuda: 41 | # torch.manual_seed(0) 42 | # torch.cuda.manual_seed_all(0) 43 | # print('=> device: ', device) 44 | 45 | 46 | # global variables 47 | # root = 'e:/pick_car_roi' # test data path 48 | 49 | # model_path = 'e:/epoch_96.pth' 50 | # attrib_path = 'e:/vehicle_attributes.pkl' # attributes file path 51 | 52 | 53 | def letterbox_image(img, inp_dim): 54 | ''' 55 | resize image with unchanged aspect ratio using padding 56 | ''' 57 | img_w, img_h = img.shape[1], img.shape[0] 58 | w, h = inp_dim 59 | new_w = int(img_w * min(w / img_w, h / img_h)) 60 | new_h = int(img_h * min(w / img_w, h / img_h)) 61 | resized_image = cv2.resize( 62 | img, (new_w, new_h), interpolation=cv2.INTER_CUBIC) 63 | 64 | canvas = np.full((inp_dim[1], inp_dim[0], 3), 128) 65 | canvas[(h - new_h) // 2:(h - new_h) // 2 + new_h, (w - new_w) // 66 | 2:(w - new_w) // 2 + new_w, :] = resized_image 67 | 68 | return canvas 69 | 70 | 71 | class Cropper(object): 72 | """ 73 | GUI tool: interactively crop rectangular ROIs and copy them via mouse and keyboard 74 | """ 75 | 76 | def __init__(self, 77 | root, 78 | dst_dir, 79 | is_resume=False): 80 | """ 81 | initialize resources 82 | @param root: directory containing the source images 83 | """ 84 | object.__init__(self) 85 | 86 | if not os.path.exists(root): 87 | print('[Err]: invalid src dir.') 88 | return 89 | 90 | if not os.path.isdir(dst_dir): 91 | os.makedirs(dst_dir) 92 | 93 | self.root = root 94 | self.imgs_path = [os.path.join(self.root, x) 95 | for x in os.listdir(self.root)] 96 | self.dst_dir = dst_dir # directory where selected ROIs are saved 97 | 98 | self.ROI = None 99 | 100 | self.clip_id = 0 101 | 
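# The resume branch below restores two pickled checkpoints written by exit():
# 'clip_idx.pkl' stores the integer index of the image being processed, and
# 'label_dict.pkl' stores the saved-ROI-filename -> label mapping (only
# populated when the manual labeling code in on_key_release is enabled).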
102 | # 加载断点 103 | print('=> is resume: ', is_resume) 104 | if is_resume == 1: 105 | self.idx = pickle.load(open('clip_idx.pkl', 'rb')) 106 | self.label_dict = pickle.load(open('label_dict.pkl', 'rb')) 107 | print('=> resume from @%d, remain %d files to be classified.' % 108 | (self.idx, len(self.imgs_path) - self.idx - 1)) 109 | elif is_resume == 0: 110 | self.idx = 0 # 初始化序号 111 | self.label_dict = {} 112 | print('=> resume from @%d, remain %d files to be classified.' % 113 | (self.idx, len(self.imgs_path))) 114 | else: 115 | print('=> [Err]: unrecognized flag.') 116 | return 117 | 118 | # 初始化车辆多标签分类管理器 119 | # self.manager = Manager(model_path=model_path, 120 | # attrib_path=attrib_path) 121 | 122 | # 创建绘图 123 | self.fig = plt.figure(figsize=(14.0, 8.0)) 124 | self.ax = self.fig.add_subplot(111) 125 | 126 | # 为绘图添加鼠标和键盘callback 127 | self.cid_scroll = self.fig.canvas.mpl_connect( 128 | 'scroll_event', self.on_scroll) 129 | self.cid_btn_press = self.fig.canvas.mpl_connect( 130 | 'button_press_event', self.on_btn_press) 131 | self.cid_btn_release = self.fig.canvas.mpl_connect( 132 | 'button_release_event', self.on_btn_release) 133 | self.cid_mouse_move = self.fig.canvas.mpl_connect( 134 | 'motion_notify_event', self.on_mouse_motion) 135 | self.cid_key_release = self.fig.canvas.mpl_connect( 136 | 'key_release_event', self.on_key_release) 137 | 138 | # 初始化鼠标按键为False 139 | self.is_btn_press = False 140 | 141 | # 初始化鼠标点击次数为0 142 | self.is_rect_ready = False 143 | 144 | # 读取图像 145 | try: 146 | img_path = self.imgs_path[self.idx] 147 | print(img_path) 148 | except Exception as e: 149 | print(e) 150 | return 151 | self.img = Image.open(img_path) 152 | 153 | # 绘制光标定位 154 | self.cursor = Cursor(self.ax, 155 | useblit=True, 156 | color='red', 157 | linewidth=1) 158 | 159 | # 初始化矩形框 160 | self.init_rect() 161 | 162 | # 绘制第一张图 163 | ax_img = self.ax.imshow(self.img, picker=True) 164 | self.ax.set_xticks([]) 165 | self.ax.set_yticks([]) 166 | plt.title(img_path) 167 | plt.tight_layout() 168 | plt.show() 169 | 170 | self.fig.canvas.draw() 171 | 172 | def init_rect(self): 173 | """ 174 | 初始化矩形框 175 | """ 176 | self.is_btn_press = False 177 | self.is_rect_ready = False 178 | self.rect = Rectangle((0, 0), 1, 1, 179 | edgecolor='b', 180 | linewidth=1, 181 | facecolor='none') 182 | self.x_0, self.y_0, self.x_1, self.y_1 = 0, 0, 0, 0 183 | self.ax.add_patch(self.rect) 184 | 185 | def exit(self): 186 | """ 187 | 退出处理 188 | """ 189 | # 关闭图像 190 | self.ax.cla() 191 | self.fig.clf() 192 | plt.close() 193 | 194 | # 保存断点 195 | pickle.dump(self.idx, open('clip_idx.pkl', 'wb')) 196 | pickle.dump(self.label_dict, open('label_dict.pkl', 'wb')) 197 | print('=> save checkpoint idx @%d, and exit.' 
% self.idx) 198 | 199 | def update_fig(self): 200 | """ 201 | 更新绘图 202 | """ 203 | if self.idx < len(self.imgs_path): 204 | # 释放上一帧缓存 205 | self.ax.cla() 206 | 207 | # 重绘一帧图像 208 | self.img = Image.open(self.imgs_path[self.idx]) # 读取图像 209 | ax_img = self.ax.imshow(self.img, picker=True) 210 | self.ax.set_xticks([]) 211 | self.ax.set_yticks([]) 212 | plt.title(str(self.idx) + ': ' + self.imgs_path[self.idx]) 213 | plt.tight_layout() 214 | 215 | # 重新初始化矩形框 216 | self.init_rect() 217 | 218 | self.fig.canvas.draw() 219 | 220 | def draw_rect(self, event): 221 | self.x_1 = event.xdata 222 | self.y_1 = event.ydata 223 | if self.x_1 > 0 and self.x_1 < self.img.width: 224 | self.rect.set_width(self.x_1 - self.x_0) 225 | self.rect.set_height(self.y_1 - self.y_0) 226 | self.rect.set_xy((self.x_0, self.y_0)) 227 | self.fig.canvas.draw() 228 | 229 | def on_scroll(self, event): 230 | """ 231 | 鼠标滚动callback 232 | """ 233 | # 清空先前图像缓存 234 | self.ax.cla() 235 | 236 | if event.button == 'down' and event.step < -0.65: # 下一张图 237 | # 更新图像数据 238 | self.idx += 1 239 | elif event.button == 'up' and event.step > 0.65: # 前一张图 240 | if self.idx == 0: # 对于第一张图, 不存在前一张图 241 | print('[Note]: idx 0 image has no previous image.') 242 | return 243 | self.idx -= 1 244 | 245 | # 更新绘图 246 | self.update_fig() 247 | 248 | def on_btn_press(self, event): 249 | """ 250 | 鼠标按下callback 251 | """ 252 | # print('=> mouse btn press') 253 | self.x_0 = event.xdata 254 | self.y_0 = event.ydata 255 | self.is_btn_press = True 256 | 257 | def on_btn_release(self, event): 258 | """ 259 | 鼠标释放callback 260 | """ 261 | # print('=> mouse btn release') 262 | if self.is_rect_ready: # 如果是奇数次按下鼠标: 恢复鼠标未被按下的状态 263 | self.is_btn_press = False 264 | 265 | x_start = int(self.rect.get_x()) 266 | x_end = int(self.rect.get_x() + self.rect.get_width()) 267 | y_start = int(self.rect.get_y()) 268 | y_end = int(self.rect.get_y() + self.rect.get_height()) 269 | if x_start < x_end and y_start < y_end: 270 | self.ROI = Image.fromarray( 271 | np.array(self.img)[y_start: y_end, x_start: x_end]) 272 | elif x_start > x_end and y_start > y_end: 273 | self.ROI = Image.fromarray( 274 | np.array(self.img)[y_end: y_start, x_end: x_start]) 275 | 276 | if None != self.ROI: # ROI是 PIL Image, 对ROI进行预测 277 | # car_color, car_direction, car_type = self.manager.predict( 278 | # self.ROI) 279 | self.ROI.show() 280 | # print('=> predict:', car_color, car_direction, car_type) 281 | 282 | # 取反 283 | self.is_rect_ready = not self.is_rect_ready 284 | 285 | def on_mouse_motion(self, event): 286 | """ 287 | 鼠标移动callback 288 | """ 289 | # print('=> mouse moving...') 290 | 291 | if self.is_btn_press: 292 | if None == event.xdata or None == event.ydata: 293 | self.is_btn_press = False 294 | return 295 | self.draw_rect(event) 296 | 297 | def on_key_release(self, event): 298 | """ 299 | 键盘按键释放callback 300 | """ 301 | if event.key == 'c': # clip and save to destination dir 302 | date_name = time.strftime( 303 | '_%Y_%m_%d_', time.localtime(time.time())) 304 | 305 | self.clip_id += 1 306 | write_name = self.dst_dir + '/' + \ 307 | date_name + \ 308 | str(self.idx) + \ 309 | '_' + \ 310 | str(self.clip_id) + \ 311 | '.jpg' 312 | self.ROI.save(write_name) 313 | print('=> %s saved.' 
% write_name) 314 | 315 | # label = input('=> Enter label string:') # 手动输入label 316 | # self.label_dict[write_name.split('/')[-1]] = label 317 | # print('=> label: ', label) 318 | 319 | # 现在并不自动跳到下一帧 320 | # self.idx += 1 321 | # self.update_fig() 322 | elif event.key == 'e': # 退出程序 323 | self.exit() 324 | self.is_btn_press = False 325 | 326 | 327 | # ----------------------------------------------------------- 328 | 329 | # 网络模型 330 | # class Net(torch.nn.Module): 331 | # """ 332 | # power-set车辆多标签分类 333 | # """ 334 | 335 | # def __init__(self, num_cls, input_size): 336 | # """ 337 | # 网络定义 338 | # :param is_freeze: 339 | # """ 340 | # torch.nn.Module.__init__(self) 341 | 342 | # # 输出通道数 343 | # self._num_cls = num_cls 344 | 345 | # # 输入图像尺寸 346 | # self.input_size = input_size 347 | 348 | # # 删除原有全连接, 得到特征提取层 349 | # self.features = torchvision.models.resnet18(pretrained=True) 350 | # del self.features.fc 351 | # # print('feature extractor:\n', self.features) 352 | 353 | # self.features = torch.nn.Sequential( 354 | # *list(self.features.children())) 355 | 356 | # # 重新定义全连接层 357 | # self.fc = torch.nn.Linear(512 ** 2, num_cls) # 输出类别数 358 | # # print('=> fc layer:\n', self.fc) 359 | 360 | # def forward(self, X): 361 | # """ 362 | # :param X: 363 | # :return: 364 | # """ 365 | # N = X.size()[0] 366 | 367 | # X = self.features(X) # extract features 368 | 369 | # X = X.view(N, 512, 1 ** 2) 370 | # X = torch.bmm(X, torch.transpose(X, 1, 2)) / (1 ** 2) # Bi-linear 371 | 372 | # X = X.view(N, 512 ** 2) 373 | # X = torch.sqrt(X + 1e-5) 374 | # X = torch.nn.functional.normalize(X) 375 | # X = self.fc(X) 376 | # assert X.size() == (N, self._num_cls) # 输出类别数 377 | # return X 378 | 379 | 380 | # 封装管理 381 | # class Manager(object): 382 | # """ 383 | # 模型初始化等 384 | # """ 385 | 386 | # def __init__(self, 387 | # model_path, 388 | # attrib_path): 389 | # """ 390 | # 加载模型并初始化 391 | # """ 392 | 393 | # # 定义模型, 放入device, 加载权重 394 | # self.net = Net(num_cls=23, 395 | # input_size=224).to(device) 396 | 397 | # # self.net = torch.nn.DataParallel(Net(num_cls=23, input_size=224), 398 | # # device_ids=[0]).to(device) 399 | 400 | # self.net.load_state_dict(torch.load(model_path)) 401 | # print('=> vehicle classifier loaded from %s' % model_path) 402 | 403 | # # 设置模型为测试模式 404 | # self.net.eval() 405 | 406 | # # 测试数据预处理方式 407 | # self.transforms = torchvision.transforms.Compose([ 408 | # torchvision.transforms.Resize(size=224), 409 | # torchvision.transforms.CenterCrop(size=224), 410 | # torchvision.transforms.ToTensor(), 411 | # torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), 412 | # std=(0.229, 0.224, 0.225)) 413 | # ]) 414 | 415 | # # 加载attributes向量 416 | # self.attributes = pickle.load(open(attrib_path, 'rb')) 417 | # self.attributes = [str(x) for x in self.attributes] 418 | # # print('=> training attributes:\n', attributes) 419 | 420 | # # 将多标签分开 421 | # self.color_attrs = self.attributes[:11] 422 | # del self.color_attrs[5] 423 | # print('=> color_attrs:\n', self.color_attrs) 424 | 425 | # self.direction_attrs = self.attributes[11:14] 426 | # del self.direction_attrs[2] 427 | # print('=> direction attrs:\n', self.direction_attrs) 428 | 429 | # self.type_attrs = self.attributes[14:] 430 | # del self.type_attrs[6] 431 | # print('=> type_attrs:\n', self.type_attrs) 432 | 433 | # def get_predict_ce(self, output): 434 | # """ 435 | # softmax归一化,然后统计每一个标签最大值索引 436 | # :param output: 437 | # :return: 438 | # """ 439 | # # 计算预测值 440 | # output = output.cpu() # 从GPU拷贝出到host端 441 | # pred_color = output[:, 
:11] 442 | # pred_direction = output[:, 11:14] 443 | # pred_type = output[:, 14:] 444 | 445 | # color_idx = pred_color.max(1, keepdim=True)[1] 446 | # direction_idx = pred_direction.max(1, keepdim=True)[1] 447 | # type_idx = pred_type.max(1, keepdim=True)[1] 448 | 449 | # # 连接pred 450 | # pred = torch.cat((color_idx, direction_idx, type_idx), dim=1) 451 | # return pred 452 | 453 | # def get_predict(self, output): 454 | # """ 455 | # 新输出向量(20维)的处理 456 | # """ 457 | # # 计算预测值 458 | # output = output.cpu() # 从GPU拷贝出到host端 459 | # pred_color = output[:, :10] 460 | # pred_direction = output[:, 10:12] 461 | # pred_type = output[:, 12:] 462 | 463 | # color_idx = pred_color.max(1, keepdim=True)[1] 464 | # direction_idx = pred_direction.max(1, keepdim=True)[1] 465 | # type_idx = pred_type.max(1, keepdim=True)[1] 466 | 467 | # # 连接pred 468 | # pred = torch.cat((color_idx, direction_idx, type_idx), dim=1) 469 | # return pred 470 | 471 | # def pre_process(self, image): 472 | # """ 473 | # 图像数据类型转换 474 | # :rtype: PIL.JpegImagePlugin.JpegImageFile 475 | # """ 476 | # # 数据预处理 477 | # if type(image) == np.ndarray: 478 | # if image.shape[2] == 3: # 3通道转换成RGB 479 | # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 480 | # elif image.shape[2] == 1: # 单通道, 灰度转换成RGB 481 | # image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) 482 | 483 | # # numpy.ndarray转换成PIL.Image 484 | # image = Image.fromarray(image) 485 | # elif type(image) == PIL.JpegImagePlugin.JpegImageFile: 486 | # if image.mode == 'L' or image.mode == 'I': # 8bit或32bit单通道灰度图转换成RGB 487 | # image = image.convert('RGB') 488 | 489 | # return image 490 | 491 | # def predict(self, img): 492 | # """ 493 | # 预测属性: 输入图像通过PIL读入的 494 | # :return:返回预测的车辆颜色、车辆朝向、车辆类别 495 | # """ 496 | # # 数据预处理 497 | # img = self.transforms(img) 498 | # img = img.view(1, 3, 224, 224) 499 | 500 | # # 图像数据放入device运行 501 | # img = img.to(device) 502 | 503 | # # 前向运算 504 | # output = self.net.forward(img) 505 | 506 | # # 获取预测结果 507 | # try: 508 | # pred = self.get_predict(output) # self.get_predict_ce, 返回的pred在host端 509 | # color_name = self.color_attrs[pred[0][0]] 510 | # direction_name = self.direction_attrs[pred[0][1]] 511 | # type_name = self.type_attrs[pred[0][2]] 512 | # except Exception as e: 513 | # return None, None, None 514 | 515 | # return color_name, direction_name, type_name 516 | 517 | 518 | def test(is_pil=True): 519 | """ 520 | 单元测试和可视化 521 | :return: 522 | """ 523 | # 测试数据路径 524 | root = 'e:/pick_car_roi' 525 | model_path = 'e:/epoch_42.pth' 526 | attrib_path = 'e:/vehicle_attributes.pkl' 527 | 528 | # 模型初始化 529 | manager = Manager(model_path=model_path, attrib_path=attrib_path) 530 | 531 | for file in os.listdir(root): 532 | # 读取测试数据 533 | file_path = os.path.join(root, file) 534 | 535 | if is_pil: 536 | image = Image.open(file_path) # 通过PIL读取图像 537 | else: 538 | image = cv2.imread(file_path, cv2.IMREAD_UNCHANGED) # 通过opencv读取图像 539 | 540 | # ------------------------------- 541 | # 图像数据格式预处理 542 | image = manager.pre_process(image) 543 | 544 | # 预测 545 | car_color, car_direction, car_type = manager.predict(image) 546 | # ------------------------------- 547 | 548 | # 可视化 549 | fig = plt.figure(figsize=(6, 6)) 550 | plt.imshow(image) 551 | plt.title(car_color + ' ' + car_direction + ' ' + car_type) 552 | plt.gca().set_xticks([]) 553 | plt.gca().set_yticks([]) 554 | plt.show() 555 | 556 | 557 | class Car_DR(): 558 | def __init__(self, 559 | src_dir, 560 | dst_dir, 561 | car_cfg_path='./car.cfg', 562 | car_det_weights_path='g:/Car_DR/car_360000.weights', 563 | inp_dim=768, 
564 | prob_th=0.2, 565 | nms_th=0.4, 566 | num_classes=1): 567 | """ 568 | 模型初始化 569 | """ 570 | # 超参数 571 | self.inp_dim = inp_dim 572 | self.prob_th = prob_th 573 | self.nms_th = nms_th 574 | self.num_classes = num_classes 575 | self.dst_dir = dst_dir 576 | 577 | # 清空dst_dir 578 | if os.path.exists(self.dst_dir): 579 | for x in os.listdir(self.dst_dir): 580 | if x.endswith('.jpg'): 581 | os.remove(self.dst_dir + '/' + x) 582 | else: 583 | os.makedirs(self.dst_dir) 584 | 585 | # 初始化车辆检测模型及参数 586 | self.Net = Darknet(car_cfg_path) 587 | self.Net.load_weights(car_det_weights_path) 588 | self.Net.net_info['height'] = self.inp_dim # 车辆检测输入分辨率 589 | self.Net.to(device) 590 | self.Net.eval() # 测试模式 591 | print('=> car detection model initiated.') 592 | 593 | # 初始化车辆多标签分类管理器 594 | self.manager = Manager(model_path=model_path, attrib_path=attrib_path) 595 | 596 | # 统计src_dir文件 597 | self.imgs_path = [os.path.join(src_dir, x) for x in os.listdir( 598 | src_dir) if x.endswith('.jpg')] 599 | 600 | def cls_draw_bbox(self, output, orig_img): 601 | """ 602 | orig_img是通过opencv读取的numpy array格式: 通道顺序BGR 603 | 在bbox基础上预测车辆属性 604 | 将bbox绘制到原图上 605 | """ 606 | labels = [] 607 | pt_1s = [] 608 | pt_2s = [] 609 | 610 | # 获取车辆属性labels 611 | for det in output: 612 | # rectangle points 613 | pt_1 = tuple(det[1:3].int()) # the left-up point 614 | pt_2 = tuple(det[3:5].int()) # the right down point 615 | pt_1s.append(pt_1) 616 | pt_2s.append(pt_2) 617 | 618 | # 调用分类器预测车辆属性: BGR => RGB 619 | ROI = Image.fromarray( 620 | orig_img[pt_1[1]: pt_2[1], 621 | pt_1[0]: pt_2[0]][:, :, ::-1]) 622 | # ROI.show() 623 | 624 | car_color, car_direction, car_type = self.manager.predict(ROI) 625 | label = str(car_color + ' ' + car_direction + ' ' + car_type) 626 | labels.append(label) 627 | print('=> predicted label: ', label) 628 | 629 | # 将bbox绘制到原图 630 | color = (0, 215, 255) 631 | for i, det in enumerate(output): 632 | pt_1 = pt_1s[i] 633 | pt_2 = pt_2s[i] 634 | 635 | # 绘制bounding box 636 | cv2.rectangle(orig_img, pt_1, pt_2, color, thickness=2) 637 | 638 | # 获取文本大小 639 | txt_size = cv2.getTextSize( 640 | label, cv2.FONT_HERSHEY_PLAIN, 2, 2)[0] # 文字大小 641 | # pt_2 = pt_1[0] + txt_size[0] + 3, pt_1[1] + txt_size[1] + 5 642 | pt_2 = pt_1[0] + txt_size[0] + 3, pt_1[1] - txt_size[1] - 5 643 | 644 | # 绘制文本底色矩形 645 | cv2.rectangle(orig_img, pt_1, pt_2, color, thickness=-1) # text 646 | 647 | # 绘制文本 648 | cv2.putText(orig_img, labels[i], (pt_1[0], pt_1[1]), # pt_1[1] + txt_size[1] + 4 649 | cv2.FONT_HERSHEY_PLAIN, 2, [225, 255, 255], 2) 650 | 651 | def cls_and_draw(self, output, orig_img): 652 | """ 653 | orig_img是PIL Image图像格式 654 | 在bbox基础上预测车辆属性 655 | 将bbox绘制到原图上 656 | """ 657 | labels = [] 658 | x_ys = [] 659 | w_hs = [] 660 | 661 | # 获取车辆属性labels 662 | for det in output: 663 | # rectangle 664 | x_y = tuple(det[1:3].int()) # x, y 665 | w_h = tuple(det[3:5].int()) # w, h 666 | x_ys.append(x_y) 667 | w_hs.append(w_h) 668 | 669 | # 调用分类器预测车辆属性: BGR => RGB 670 | box = (int(x_y[0]), int(x_y[1]), int(x_y[0] + w_h[0]), 671 | int(x_y[1] + w_h[1])) # left, upper, right, lower 672 | ROI = orig_img.crop(box) 673 | 674 | car_color, car_direction, car_type = self.manager.predict(ROI) 675 | label = car_color + ' ' + car_direction + ' ' + car_type 676 | print('=> label: ', label) 677 | labels.append(label) 678 | 679 | # 将bbox绘制到原图 680 | for i, det in enumerate(output): 681 | x_y = x_ys[i] 682 | w_h = w_hs[i] 683 | 684 | color = (0, 215, 255) 685 | cv2.rectangle(np.asarray(orig_img), x_y, w_h, color, 686 | thickness=2) # bounding box 687 | 688 | 
txt_size = cv2.getTextSize( 689 | label, cv2.FONT_HERSHEY_PLAIN, 2, 2)[0] # 文字大小 690 | w_h = x_y[0] + txt_size[0] + 4, x_y[1] + txt_size[1] + 4 691 | cv2.rectangle(np.asarray(orig_img), x_y, w_h, 692 | color, thickness=-1) # text 693 | cv2.putText(np.asarray(orig_img), labels[i], (x_y[0], x_y[1] + txt_size[1] + 4), 694 | cv2.FONT_HERSHEY_PLAIN, 2, [225, 255, 255], 2) 695 | 696 | def predict(self): 697 | """ 698 | 批量检测和识别, 将检测, 识别结果输出到dst_dir 699 | """ 700 | for x in self.imgs_path: 701 | # 读取图像数据 702 | img = Image.open(x) 703 | img2det = process_img(img, self.inp_dim) 704 | img2det = img2det.to(device) # 图像数据放到device 705 | 706 | # 车辆检测 707 | prediction = self.Net.forward(img2det, CUDA=True) 708 | 709 | # 计算scaling factor 710 | orig_img_size = list(img.size) 711 | output = process_predict(prediction, 712 | self.prob_th, 713 | self.num_classes, 714 | self.nms_th, 715 | self.inp_dim, 716 | orig_img_size) 717 | 718 | orig_img = cv2.cvtColor(np.asarray( 719 | img), cv2.COLOR_RGB2BGR) # RGB => BGR 720 | if type(output) != int: 721 | # 将检测框bbox绘制到原图上 722 | # draw_car_bbox(output, orig_img) 723 | self.cls_draw_bbox(output, orig_img) 724 | # self.cls_and_draw(output, img) 725 | dst_path = self.dst_dir + '/' + os.path.split(x)[1] 726 | if not os.path.exists(dst_path): 727 | cv2.imwrite(dst_path, orig_img) 728 | 729 | # ----------------------------------------------------------- 730 | 731 | 732 | def test_car_detect(car_cfg_path='./car.cfg', 733 | car_det_weights_path='g:/Car_DR/car_360000.weights'): 734 | """ 735 | imgs_path: 图像数据路径 736 | """ 737 | inp_dim = 768 738 | prob_th = 0.2 # 车辆检测概率阈值 739 | nms_th = 0.4 # NMS阈值 740 | num_cls = 1 # 只检测车辆1类 741 | 742 | # 初始化车辆检测模型及参数 743 | Net = Darknet(car_cfg_path) 744 | Net.load_weights(car_det_weights_path) 745 | Net.net_info['height'] = inp_dim # 车辆检测输入分辨率 746 | Net.to(device) 747 | Net.eval() # 测试模式 748 | print('=> car detection model initiated.') 749 | 750 | # 读取图像数据 751 | img = Image.open( 752 | 'f:/FaceRecognition_torch_0_4/imgs_21/det_2018_08_21_63_1.jpg') 753 | img2det = process_img(img, inp_dim) 754 | img2det = img2det.to(device) # 图像数据放到device 755 | 756 | # 测试车辆检测 757 | prediction = Net.forward(img2det, CUDA=True) 758 | 759 | # 计算scaling factor 760 | orig_img_size = list(img.size) 761 | output = process_predict(prediction, 762 | prob_th, 763 | num_cls, 764 | nms_th, 765 | inp_dim, 766 | orig_img_size) 767 | 768 | orig_img = np.asarray(img) 769 | if type(output) != int: 770 | # 将检测框bbox绘制到原图上 771 | draw_car_bbox(output, orig_img) 772 | 773 | cv2.imshow('test', orig_img) 774 | cv2.waitKey() 775 | 776 | 777 | """ 778 | # prep_ret = prep_image('f:/FaceRecognition_torch_0_4/imgs_21/det_2018_08_21_63_1.jpg', 779 | # inp_dim) # 返回一个Tensor 780 | # img2det = prep_ret[0].view(1, 3, inp_dim, inp_dim) 781 | # Net.load_state_dict(torch.load('./car_detect_model.pth')) 782 | """ 783 | 784 | 785 | def draw_car_bbox(output, orig_img): 786 | for det in output: 787 | label = 'car' # 类型名称 788 | prob = '{:.3f}'.format(det[5].cpu().numpy()) 789 | label += prob 790 | 791 | x_y = tuple(det[1:3].int()) # x, y 792 | w_h = tuple(det[3:5].int()) # w, h 793 | 794 | color = (0, 215, 255) 795 | cv2.rectangle(orig_img, x_y, w_h, color, 796 | thickness=2) # bounding box 797 | 798 | txt_size = cv2.getTextSize( 799 | label, cv2.FONT_HERSHEY_PLAIN, 2, 2)[0] # 文字大小 800 | w_h = x_y[0] + txt_size[0] + 3, x_y[1] + txt_size[1] + 4 801 | cv2.rectangle(orig_img, x_y, w_h, color, thickness=-1) # text 802 | cv2.putText(orig_img, label, (x_y[0], x_y[1] + txt_size[1] + 4), 803 | 
cv2.FONT_HERSHEY_PLAIN, 2, [225, 255, 255], 2) 804 | 805 | 806 | def process_predict(prediction, 807 | prob_th, 808 | num_cls, 809 | nms_th, 810 | inp_dim, 811 | orig_img_size): 812 | """ 813 | 处理预测结果 814 | """ 815 | scaling_factor = min([inp_dim / float(x) 816 | for x in orig_img_size]) # W, H缩放系数 817 | output = post_process(prediction, 818 | prob_th, 819 | num_cls, 820 | nms=True, 821 | nms_conf=nms_th, 822 | CUDA=True) # post-process such as nms 823 | 824 | if type(output) != int: 825 | output[:, [1, 3]] -= (inp_dim - scaling_factor * 826 | orig_img_size[0]) / 2.0 # x, w 827 | output[:, [2, 4]] -= (inp_dim - scaling_factor * 828 | orig_img_size[1]) / 2.0 # y, h 829 | output[:, 1:5] /= scaling_factor 830 | for i in range(output.shape[0]): 831 | output[i, [1, 3]] = torch.clamp( 832 | output[i, [1, 3]], 0.0, orig_img_size[0]) 833 | output[i, [2, 4]] = torch.clamp( 834 | output[i, [2, 4]], 0.0, orig_img_size[1]) 835 | return output 836 | 837 | 838 | def test_equal(f_path_1, f_path_2): 839 | """ 840 | f_path_1: 第一个文件路径 841 | f_path_2: 第二个文件路径 842 | """ 843 | arr_1 = np.load(f_path_1)['arr_0'] 844 | arr_2 = np.load(f_path_2)['arr_0'] 845 | 846 | # 判断两个数组是否逐元素相等 847 | print('=> the two array is equal:', (arr_1 == arr_2).all()) 848 | 849 | 850 | # --------------------------------将clipper处理的数据合并回vehicle_train 851 | def process_clipped(src_root, dst_root): 852 | """ 853 | 将src_root中的数据按照label合并到dst_root对应子目录 854 | """ 855 | # 加载label_dict 856 | # label_path = src_root + '/' + 'label_dict.pkl' 857 | try: 858 | label_dict = pickle.load( 859 | open('f:/FaceRecognition_torch_0_4/label_dict.pkl', 'rb')) 860 | # print(label_dict) 861 | except Exception as e: 862 | print(e) 863 | 864 | # 遍历src_root 865 | for x in os.listdir(src_root): 866 | if x.endswith('.jpg'): # 只处理存在的jpg图 867 | if x in label_dict.keys(): # 只处理存在key的数据 868 | label = label_dict[x] 869 | # print('=> key: %s, value: %s' % (x, label)) 870 | sub_dir_path = dst_root + '/' + label.replace(' ', '_') 871 | # print(sub_dir_path) 872 | 873 | # 如果src, dst文件存在才合并 874 | if os.path.isdir(sub_dir_path): 875 | src_path = src_root + '/' + x 876 | if os.path.exists(src_path): 877 | dst_path = sub_dir_path + '/' + x 878 | if not os.path.exists(dst_path): # 如果已经存, 则不再拷贝 879 | shutil.copy(src_path, sub_dir_path) 880 | print('=> %s copied to %s' % 881 | (src_path, sub_dir_path)) 882 | 883 | # ---------------------------- 884 | 885 | 886 | def viz_err(err_path, root='f:/'): 887 | """ 888 | 可视化分类错误信息 889 | """ 890 | err_dict = pickle.load(open(err_path, 'rb')) 891 | # print(err_dict) 892 | 893 | fig = plt.figure() # 894 | 895 | for k, v in err_dict.items(): 896 | img_path = root + k 897 | if os.path.isfile(img_path): 898 | img = Image.open(img_path) 899 | plt.gcf().set_size_inches(8, 8) 900 | plt.imshow(img) 901 | plt.title(img_path + '\n' + v) 902 | plt.gca().set_xticks([]) 903 | plt.gca().set_yticks([]) 904 | plt.show() 905 | 906 | 907 | if __name__ == '__main__': 908 | # ---------------------------- Clip roi, labeling and copy 909 | parser = argparse.ArgumentParser(description='Cropper parameters') 910 | parser.add_argument('-src', 911 | type=str, 912 | dest='s', 913 | default=u'f:/LPVehicleID_1/', 914 | help='dir path of JPEGImages') 915 | parser.add_argument('-dst', 916 | type=str, 917 | dest='d', 918 | default=u'f:/LPVehicleID_pro/', 919 | help='dir path of JPEGImages') 920 | parser.add_argument('-folder', 921 | type=str, 922 | dest='f', 923 | default=u'桂A66K53', 924 | help='dir path of JPEGImages') 925 | parser.add_argument('-r', 926 | type=int, 927 | 
default=0, 928 | help='dir path of JPEGImages') 929 | args = parser.parse_args() 930 | 931 | cropper = Cropper(root=args.s + args.f, 932 | dst_dir=args.d + args.f, 933 | is_resume=args.r) 934 | 935 | # process_clipped(src_root=u'f:/LPVehicleID_1/川A1D695', 936 | # dst_root='f:/vehicle_train') 937 | 938 | # ---------------------------- 939 | # test_car_detect() 940 | 941 | # ---------------------------- Car detect and classify 942 | # DR_model = Car_DR(src_dir='g:/car_0819', 943 | # dst_dir='f:/test_result') 944 | # DR_model.predict() 945 | 946 | # ---------------------------- 947 | # test_equal('e:/prediction_1.npz', 'e:/prediction_2.npz'c) 948 | # test() 949 | 950 | # ---------------------------- 951 | 952 | # viz_err('g:/err_dict.pkl') 953 | print('=> Test done.') 954 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Even 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Vehicle-Car-detection-and-multilabel-classification 车辆检测和多标签属性识别 2 | ## 一个基于Pytorch精简的框架,使用YOLO_v3_tiny和B-CNN实现街头车辆的检测和车辆属性的多标签识别。
(A concise PyTorch-based framework that uses YOLO_v3_tiny and B-CNN to detect vehicles in street scenes and recognize their attributes via multi-label classification) 3 | 4 | ## 效果如下: Vehicle detection and recognition results are as follows:<br>
5 | ![](https://github.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/blob/master/test_result/test_5.jpg) 6 | ![](https://github.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/blob/master/test_result/test_17.jpg) 7 |
8 | 9 | ## 使用方法 Usage 10 | python VehicleDC.py -src-dir your_imgs_dir -dst-dir your_result_dir 11 | 12 | ## 训练好的模型文件(包括车辆检测模型和多标签分类模型) trained models on baidu drive 13 | [Trained models-vehicle detection](https://pan.baidu.com/s/1HwTCVGTmdqkeLnqnxfNL8Q)<br>
14 | [Trained models-vehicle classification](https://pan.baidu.com/s/1XmzjvCgOrrVv0NWTt4Fm3g)<br>
15 | 在运行Vehicle_DC脚本之前,先下载上面的模型文件或者使用自己预先训练好的模型文件,将car_540000.weights(用于检测)放在项目根目录,将epoch_39.pth(用于多标签识别)放在根目录下的checkpoints目录下,即可使用Vehicle_DC运行。
16 | Before running Vehicle_DC, you should download the model files provided above or use your own pre-trained models. If using the provided models, place car_540000.weights (for detection) in the root directory of this project, and place epoch_39.pth (for multi-label recognition) under root/checkpoints/. 17 | 18 | ### 程序简介 brief introductions 19 | #### (1). 程序包含两大模块:<br>
The program consists of two parts: first, car detection (only model loading and inference code is provided; if you need training code, you can refer to [pytorch_yolo_v3](https://github.com/eriklindernoren/PyTorch-YOLOv3#train)); second, car attribute classification (both training and testing code are provided; it predicts a vehicle's body color, body direction and car type) 20 | ##### <1>. 车辆检测模块: 只提供检测, 训练代码可以参考[pytorch_yolo_v3](https://github.com/eriklindernoren/PyTorch-YOLOv3#train);<br>
21 | ##### <2>. 多标签识别模块:包含车辆颜色、车辆朝向、车辆类型(multi-label recognition module: vehicle body color, orientation and type) 22 | 将这两个模块结合在一起,可以同时实现车辆的检测和识别。以此为基础,可以对室外智能交通信息,进行一定程度的结构化信息提取。<br>
23 | Combining these two modules, you can do vehicle detection and multi-label recognition at the same time. On this basis, structured information can be extracted from outdoor traffic scenes to a certain extent. 24 | #### (2). 程序模块详解 modules detailed introduction<br>
25 | ##### <1>. VehicleDC.py
26 | 此模块是车辆检测和车辆多标签识别接口的封装,需要指定测试源目录和结果输出目录。主类Car_DC, 函数__init__主要负责汽车检测、汽车识别两个模型的初始化。 27 | 函数detect_classify负责逐张对图像进行检测和识别:首先对输入图像进行预处理,统一输入格式,然后,输出该图像所有的车的检测框。通过函数process_predict做nms, 坐标系转换,得到所有最终的检测框。然后,程序调用函数cls_draw_bbox,在cls_draw_bbox中,逐一处理每个检测框。首先,取出原图像检测框区域检测框对应的的ROI(region of interest), 将ROI送入车辆多标签分类器。分类器调用B-CNN算法对ROI中的车辆进行多标签属性分类。参考[paper link](http://vis-www.cs.umass.edu/bcnn/docs/bcnn_iccv15.pdf)。B-CNN主要用于训练端到端的细粒度分类。本程序对论文中的网络结构做了一定的适应性修改:为了兼顾程序的推断速度和准确度,不同于论文中采用的Vgg-16,这里的B-CNN的基础网络采用Resnet-18。
28 | This module encapsulates the interfaces for vehicle detection and multi-label classification. You need to specify the source directory and the result directory. The main class is Car_DC. The pretrained models are loaded and initialized in function __init__(). In function detect_classify, each input image is pre-processed into a uniform format, and the raw bounding boxes are then output for further NMS and coordinate transformation. Classification and bounding-box drawing are done in function cls_draw_bbox based on the bounding-box ROIs. A bilinear CNN is used for fine-grained classification, with ResNet-18 as the backbone instead of VGG-16 as a trade-off between accuracy and speed. 29 | ##### 耗时统计 Time consumption 30 | 车辆检测: 单张图像推断耗时,在单个GTX 1050TI GPU上约18ms。<br>
31 | 车辆多标签识别:单张图像推断耗时,在单个GTX TITAN GPU上约7ms,在单个GTX 1050TI GPU上约10ms。
32 | Vehicle detection: single image inference costs about 18ms on a single GTX 1050TI.<br>
33 | Vehicle classification: single image inference costs about 10ms on a single GTX 1050TI (about 7ms on a single GTX TITAN). 34 | 35 | ##### <2>. 车辆多标签数据模块(由于保密协议等原因暂时不能公开数据集;the dataset cannot be released for now due to a confidentiality agreement) dataset.py<br>
36 | 训练、测试数据类别按照子目录存放,子目录名即label,Color_Direction_type,如Yellow_Rear_suv。(Training and test data are organized into sub-directories by category; the sub-directory name is the label, in the form Color_Direction_Type, e.g. Yellow_Rear_suv; see the sketch below.)<br>
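A minimal sketch of this naming convention (not code from dataset.py; the helper name is hypothetical):

```python
# Hypothetical helper illustrating the Color_Direction_Type sub-directory
# naming convention, e.g. 'Yellow_Rear_suv' -> ('Yellow', 'Rear', 'suv').
def parse_label_from_dir(dir_name):
    color, direction, type_ = dir_name.split('_', 2)
    return color, direction, type_

print(parse_label_from_dir('Yellow_Rear_suv'))  # ('Yellow', 'Rear', 'suv')
```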
37 | Vehicle类重载了data.Dataset的init, getitem, len方法(the Vehicle class overrides the __init__, __getitem__ and __len__ methods of data.Dataset):<br>
38 | 函数__init__负责初始化数据路径,数据标签,由于数据标签是多标签类型,故对输出向量分段计算交叉熵loss即可。(__init__ initializes the data paths and labels; since each sample carries multiple labels, the cross-entropy loss is simply computed segment by segment over the output vector.)<br>
39 | 函数__getitem__负责迭代返回数据和标签,返回的数据需要经过标准化等预处理;函数__len__获取数据的总数量。(__getitem__ returns one sample and its label per iteration, with the returned data pre-processed, e.g. normalized; __len__ returns the total number of samples.) 40 | 41 | ##### <3>. 车辆多标签训练、测试模块 train_vehicle_multilabel.py 42 | 此模块负责车辆多标签的训练和测试。训练过程选择交叉熵作为损失函数,需要注意的是,由于是多标签分类,故计算loss的时候需要累加各个标签的loss,其中loss = loss_color + loss_direction + 2.0 * loss_type,根据经验,将车辆类型的loss权重放大到2倍效果较好。(This module handles training and testing of the vehicle multi-label classifier. Cross-entropy is chosen as the loss function; note that, since this is multi-label classification, the per-label losses are summed: loss = loss_color + loss_direction + 2.0 * loss_type. Empirically, doubling the weight of the vehicle-type loss works better; a sketch follows below.) 43 | 
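A minimal sketch of this segmented multi-label loss (not the exact code from train_vehicle_multilabel.py; the 9/2/8 split of the 19-dim output vector follows the segmentation used in VehicleDC.py):

```python
import torch
import torch.nn.functional as F

# Segmented cross-entropy over a 19-dim output: color (9), direction (2), type (8).
def multilabel_loss(output, color_gt, direction_gt, type_gt):
    loss_color = F.cross_entropy(output[:, :9], color_gt)
    loss_direction = F.cross_entropy(output[:, 9:11], direction_gt)
    loss_type = F.cross_entropy(output[:, 11:], type_gt)
    # empirically the vehicle-type term is weighted by 2.0
    return loss_color + loss_direction + 2.0 * loss_type

# toy usage with random logits and labels
out = torch.randn(4, 19)
loss = multilabel_loss(out, torch.randint(9, (4,)),
                       torch.randint(2, (4,)), torch.randint(8, (4,)))
```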
44 | 另一方面,训练分为两步:(1). 冻结Resnet-18除全连接层之外的所有层,Fine-tune训练到收敛为止;(2). 打开第一步中冻结的所有层,进一步Fine-tune训练,调整所有层的权重,直至整个模型收敛为止。(Training is done in two steps: (1) freeze all layers of ResNet-18 except the fully-connected layer and fine-tune until convergence; (2) unfreeze the layers frozen in step one and fine-tune further, adjusting the weights of all layers until the whole model converges.) 45 | -------------------------------------------------------------------------------- /VehicleDC.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import os 4 | import sys 5 | import re 6 | import time 7 | import pickle 8 | import shutil 9 | import random 10 | import argparse 11 | 12 | from darknet_util import * 13 | from darknet import Darknet 14 | from preprocess import prep_image, process_img, inp_to_image 15 | from dataset import color_attrs, direction_attrs, type_attrs 16 | 17 | import torch 18 | import torchvision 19 | import paramiko 20 | import cv2 21 | import numpy as np 22 | import PIL 23 | from PIL import Image 24 | from matplotlib import pyplot as plt 25 | from matplotlib.widgets import Cursor 26 | from matplotlib.image import AxesImage 27 | from scipy.spatial.distance import cityblock 28 | from tqdm import tqdm 29 | 30 | # ------------------------------------- 31 | # for matplotlib to display Chinese characters correctly 32 | from pylab import * 33 | mpl.rcParams['font.sans-serif'] = ['SimHei'] 34 | 35 | use_cuda = True # True 36 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 37 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 38 | device = torch.device( 39 | 'cuda:0' if torch.cuda.is_available() and use_cuda else 'cpu') 40 | 41 | if use_cuda: 42 | torch.manual_seed(0) 43 | torch.cuda.manual_seed_all(0) 44 | print('=> device: ', device) 45 | 46 | local_model_path = './checkpoints/epoch_39.pth' 47 | local_car_cfg_path = './car.cfg' 48 | local_car_det_weights_path = './car_detect.weights' 49 | 50 | 51 | 52 | class Cls_Net(torch.nn.Module): 53 | """ 54 | vehicle multilabel classification model 55 | """ 56 | 57 | def __init__(self, num_cls, input_size): 58 | """ 59 | network definition 60 | :param is_freeze: 61 | """ 62 | torch.nn.Module.__init__(self) 63 | 64 | # output channels 65 | self._num_cls = num_cls 66 | 67 | # input image size 68 | self.input_size = input_size 69 | 70 | # delete original FC and add custom FC 71 | self.features = torchvision.models.resnet18(pretrained=True) 72 | del self.features.fc 73 | # print('feature extractor:\n', self.features) 74 | 75 | self.features = torch.nn.Sequential( 76 | *list(self.features.children())) 77 | 78 | self.fc = torch.nn.Linear(512 ** 2, num_cls) # number of output classes 79 | # print('=> fc layer:\n', self.fc) 80 | 81 | def forward(self, X): 82 | """ 83 | :param X: 84 | :return: 85 | """ 86 | N = X.size()[0] 87 | 88 | X = self.features(X) # extract features 89 | 90 | X = X.view(N, 512, 1 ** 2) 91 | X = torch.bmm(X, torch.transpose(X, 1, 2)) / (1 ** 2) # Bi-linear CNN 92 | 93 | X = X.view(N, 512 ** 2) 94 | X = torch.sqrt(X + 1e-5) 95 | X = torch.nn.functional.normalize(X) 96 | X = self.fc(X) 97 | assert X.size() == (N, self._num_cls) 98 | return X 99 | 100 | 101 | # ------------------------------------- vehicle multilabel classification model 102 | class Car_Classifier(object): 103 | """ 104 | vehicle multilabel classification model manager 105 | """ 106 | 107 | def __init__(self, 108 | num_cls, 109 | model_path=local_model_path): 110 | """ 111 | load model and initialize 112 | """ 113 | 114 | # define model and load weights 115 | self.net = Cls_Net(num_cls=num_cls, input_size=224).to(device) 116 | # self.net = torch.nn.DataParallel(Net(num_cls=20, input_size=224), 117 | # device_ids=[0]).to(device) 118 | self.net.load_state_dict(torch.load(model_path)) 119 | print('=> vehicle classifier loaded from %s' % 
model_path) 120 | 121 | # set model to eval mode 122 | self.net.eval() 123 | 124 | # test data transforms 125 | self.transforms = torchvision.transforms.Compose([ 126 | torchvision.transforms.Resize(size=224), 127 | torchvision.transforms.CenterCrop(size=224), 128 | torchvision.transforms.ToTensor(), 129 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), 130 | std=(0.229, 0.224, 0.225)) 131 | ]) 132 | 133 | # split each label 134 | self.color_attrs = color_attrs 135 | print('=> color_attrs:\n', self.color_attrs) 136 | 137 | self.direction_attrs = direction_attrs 138 | print('=> direction attrs:\n', self.direction_attrs) 139 | 140 | self.type_attrs = type_attrs 141 | print('=> type_attrs:\n', self.type_attrs) 142 | 143 | def get_predict(self, output): 144 | """ 145 | get prediction from output 146 | """ 147 | # get each label's prediction from output 148 | output = output.cpu() # fetch data from gpu 149 | pred_color = output[:, :9] 150 | pred_direction = output[:, 9:11] 151 | pred_type = output[:, 11:] 152 | 153 | color_idx = pred_color.max(1, keepdim=True)[1] 154 | direction_idx = pred_direction.max(1, keepdim=True)[1] 155 | type_idx = pred_type.max(1, keepdim=True)[1] 156 | pred = torch.cat((color_idx, direction_idx, type_idx), dim=1) 157 | return pred 158 | 159 | def pre_process(self, image): 160 | """ 161 | image formatting 162 | :rtype: PIL.JpegImagePlugin.JpegImageFile 163 | """ 164 | # image data formatting 165 | if type(image) == np.ndarray: 166 | if image.shape[2] == 3: # turn all 3 channels to RGB format 167 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 168 | elif image.shape[2] == 1: # turn 1 channel to RGB 169 | image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) 170 | 171 | # turn numpy.ndarray into PIL.Image 172 | image = Image.fromarray(image) 173 | elif type(image) == PIL.JpegImagePlugin.JpegImageFile: 174 | if image.mode == 'L' or image.mode == 'I': # turn 8bits or 32bits into 3 channels RGB 175 | image = image.convert('RGB') 176 | 177 | return image 178 | 179 | def predict(self, img): 180 | """ 181 | predict vehicle attributes by classifying 182 | :return: vehicle color, direction and type 183 | """ 184 | # image pre-processing 185 | img = self.transforms(img) 186 | img = img.view(1, 3, 224, 224) 187 | 188 | # put image data into device 189 | img = img.to(device) 190 | 191 | # calculating inference 192 | output = self.net.forward(img) 193 | 194 | # get result 195 | # self.get_predict_ce, return pred to host side(cpu) 196 | pred = self.get_predict(output) 197 | color_name = self.color_attrs[pred[0][0]] 198 | direction_name = self.direction_attrs[pred[0][1]] 199 | type_name = self.type_attrs[pred[0][2]] 200 | 201 | return color_name, direction_name, type_name 202 | 203 | 204 | class Car_DC(): 205 | def __init__(self, 206 | src_dir, 207 | dst_dir, 208 | car_cfg_path=local_car_cfg_path, 209 | car_det_weights_path=local_car_det_weights_path, 210 | inp_dim=768, 211 | prob_th=0.2, 212 | nms_th=0.4, 213 | num_classes=1): 214 | """ 215 | model initialization 216 | """ 217 | # super parameters 218 | self.inp_dim = inp_dim 219 | self.prob_th = prob_th 220 | self.nms_th = nms_th 221 | self.num_classes = num_classes 222 | self.dst_dir = dst_dir 223 | 224 | # clear dst_dir 225 | if os.path.exists(self.dst_dir): 226 | for x in os.listdir(self.dst_dir): 227 | if x.endswith('.jpg'): 228 | os.remove(self.dst_dir + '/' + x) 229 | else: 230 | os.makedirs(self.dst_dir) 231 | 232 | # initialize vehicle detection model 233 | self.detector = Darknet(car_cfg_path) 234 | 
self.detector.load_weights(car_det_weights_path) 235 | # set input dimension of image 236 | self.detector.net_info['height'] = self.inp_dim 237 | self.detector.to(device) 238 | self.detector.eval() # evaluation mode 239 | print('=> car detection model initiated.') 240 | 241 | # initiate multilabel classifier 242 | self.classifier = Car_Classifier(num_cls=19, 243 | model_path=local_model_path) 244 | 245 | # initiate imgs_path 246 | self.imgs_path = [os.path.join(src_dir, x) for x in os.listdir( 247 | src_dir) if x.endswith('.jpg')] 248 | 249 | def cls_draw_bbox(self, output, orig_img): 250 | """ 251 | 1. predict vehicle's attributes based on bbox of vehicle 252 | 2. draw bbox to orig_img 253 | """ 254 | labels = [] 255 | pt_1s = [] 256 | pt_2s = [] 257 | 258 | # 1 259 | for det in output: 260 | # rectangle points 261 | pt_1 = tuple(det[1:3].int()) # the left-up point 262 | pt_2 = tuple(det[3:5].int()) # the right down point 263 | pt_1s.append(pt_1) 264 | pt_2s.append(pt_2) 265 | 266 | # turn BGR back to RGB 267 | ROI = Image.fromarray( 268 | orig_img[pt_1[1]: pt_2[1], 269 | pt_1[0]: pt_2[0]][:, :, ::-1]) 270 | # ROI.show() 271 | 272 | # call classifier to predict 273 | car_color, car_direction, car_type = self.classifier.predict(ROI) 274 | label = str(car_color + ' ' + car_direction + ' ' + car_type) 275 | labels.append(label) 276 | print('=> predicted label: ', label) 277 | 278 | # 2 279 | color = (0, 215, 255) 280 | for i, det in enumerate(output): 281 | pt_1 = pt_1s[i] 282 | pt_2 = pt_2s[i] 283 | 284 | # draw bounding box 285 | cv2.rectangle(orig_img, pt_1, pt_2, color, thickness=2) 286 | 287 | # get str text size 288 | txt_size = cv2.getTextSize( 289 | label, cv2.FONT_HERSHEY_PLAIN, 2, 2)[0] 290 | # pt_2 = pt_1[0] + txt_size[0] + 3, pt_1[1] + txt_size[1] + 5 291 | pt_2 = pt_1[0] + txt_size[0] + 3, pt_1[1] - txt_size[1] - 5 292 | 293 | # draw text background rect 294 | cv2.rectangle(orig_img, pt_1, pt_2, color, thickness=-1) # text 295 | 296 | # draw text 297 | cv2.putText(orig_img, labels[i], (pt_1[0], pt_1[1]), # pt_1[1] + txt_size[1] + 4 298 | cv2.FONT_HERSHEY_PLAIN, 2, [225, 255, 255], 2) 299 | 300 | def process_predict(self, 301 | prediction, 302 | prob_th, 303 | num_cls, 304 | nms_th, 305 | inp_dim, 306 | orig_img_size): 307 | """ 308 | processing detections 309 | """ 310 | scaling_factor = min([inp_dim / float(x) 311 | for x in orig_img_size]) # W, H scaling factor 312 | output = post_process(prediction, 313 | prob_th, 314 | num_cls, 315 | nms=True, 316 | nms_conf=nms_th, 317 | CUDA=True) # post-process such as nms 318 | 319 | if type(output) != int: 320 | output[:, [1, 3]] -= (inp_dim - scaling_factor * 321 | orig_img_size[0]) / 2.0 # x, w 322 | output[:, [2, 4]] -= (inp_dim - scaling_factor * 323 | orig_img_size[1]) / 2.0 # y, h 324 | output[:, 1:5] /= scaling_factor 325 | for i in range(output.shape[0]): 326 | output[i, [1, 3]] = torch.clamp( 327 | output[i, [1, 3]], 0.0, orig_img_size[0]) 328 | output[i, [2, 4]] = torch.clamp( 329 | output[i, [2, 4]], 0.0, orig_img_size[1]) 330 | return output 331 | 332 | def detect_classify(self): 333 | """ 334 | detect and classify 335 | """ 336 | for x in self.imgs_path: 337 | # read image data 338 | img = Image.open(x) 339 | img2det = process_img(img, self.inp_dim) 340 | img2det = img2det.to(device) # put image data to device 341 | 342 | # vehicle detection 343 | prediction = self.detector.forward(img2det, CUDA=True) 344 | 345 | # calculating scaling factor 346 | orig_img_size = list(img.size) 347 | output = 
self.process_predict(prediction, 348 | self.prob_th, 349 | self.num_classes, 350 | self.nms_th, 351 | self.inp_dim, 352 | orig_img_size) 353 | 354 | orig_img = cv2.cvtColor(np.asarray( 355 | img), cv2.COLOR_RGB2BGR) # RGB => BGR 356 | if type(output) != int: 357 | self.cls_draw_bbox(output, orig_img) 358 | dst_path = self.dst_dir + '/' + os.path.split(x)[1] 359 | if not os.path.exists(dst_path): 360 | cv2.imwrite(dst_path, orig_img) 361 | 362 | # ----------------------------------------------------------- 363 | 364 | 365 | parser = argparse.ArgumentParser(description='Detect and classify cars.') 366 | parser.add_argument('-src-dir', 367 | type=str, 368 | default='./test_imgs', 369 | help='source directory of images') 370 | parser.add_argument('-dst-dir', 371 | type=str, 372 | default='./test_result', 373 | help='destination directory of images to store results.') 374 | 375 | if __name__ == '__main__': 376 | # ---------------------------- Car detect and classify 377 | # DR_model = Car_DC(src_dir='./test_imgs', 378 | # dst_dir='./test_result') 379 | # DR_model.detect_classify() 380 | 381 | args = parser.parse_args() 382 | DR_model = Car_DC(src_dir=args.src_dir, dst_dir=args.dst_dir) 383 | DR_model.detect_classify() 384 | -------------------------------------------------------------------------------- /bbox.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | import random 5 | 6 | import numpy as np 7 | import cv2 8 | 9 | def confidence_filter(result, confidence): 10 | conf_mask = (result[:,:,4] > confidence).float().unsqueeze(2) 11 | result = result*conf_mask 12 | 13 | return result 14 | 15 | def confidence_filter_cls(result, confidence): 16 | max_scores = torch.max(result[:,:,5:25], 2)[0] 17 | res = torch.cat((result, max_scores),2) 18 | print(res.shape) 19 | 20 | 21 | cond_1 = (res[:,:,4] > confidence).float() 22 | cond_2 = (res[:,:,25] > 0.995).float() 23 | 24 | conf = cond_1 + cond_2 25 | conf = torch.clamp(conf, 0.0, 1.0) 26 | conf = conf.unsqueeze(2) 27 | result = result*conf 28 | return result 29 | 30 | 31 | 32 | def get_abs_coord(box): 33 | box[2], box[3] = abs(box[2]), abs(box[3]) 34 | x1 = (box[0] - box[2]/2) - 1 35 | y1 = (box[1] - box[3]/2) - 1 36 | x2 = (box[0] + box[2]/2) - 1 37 | y2 = (box[1] + box[3]/2) - 1 38 | return x1, y1, x2, y2 39 | 40 | 41 | 42 | def sanity_fix(box): 43 | if (box[0] > box[2]): 44 | box[0], box[2] = box[2], box[0] 45 | 46 | if (box[1] > box[3]): 47 | box[1], box[3] = box[3], box[1] 48 | 49 | return box 50 | 51 | def bbox_iou(box1, box2, CUDA=True): 52 | """ 53 | Returns the IoU of two bounding boxes 54 | 55 | 56 | """ 57 | #Get the coordinates of bounding boxes 58 | b1_x1, b1_y1, b1_x2, b1_y2 = box1[:,0], box1[:,1], box1[:,2], box1[:,3] 59 | b2_x1, b2_y1, b2_x2, b2_y2 = box2[:,0], box2[:,1], box2[:,2], box2[:,3] 60 | 61 | #get the corrdinates of the intersection rectangle 62 | inter_rect_x1 = torch.max(b1_x1, b2_x1) 63 | inter_rect_y1 = torch.max(b1_y1, b2_y1) 64 | inter_rect_x2 = torch.min(b1_x2, b2_x2) 65 | inter_rect_y2 = torch.min(b1_y2, b2_y2) 66 | 67 | #Intersection area 68 | if CUDA: # torch.cuda.is_available(): 69 | inter_area = torch.max(inter_rect_x2 - inter_rect_x1 + 1,torch.zeros(inter_rect_x2.shape).cuda())*torch.max(inter_rect_y2 - inter_rect_y1 + 1, torch.zeros(inter_rect_x2.shape).cuda()) 70 | else: 71 | inter_area = torch.max(inter_rect_x2 - inter_rect_x1 + 1,torch.zeros(inter_rect_x2.shape))*torch.max(inter_rect_y2 - inter_rect_y1 + 1, 
torch.zeros(inter_rect_x2.shape)) 72 | 73 | #Union Area 74 | b1_area = (b1_x2 - b1_x1 + 1)*(b1_y2 - b1_y1 + 1) 75 | b2_area = (b2_x2 - b2_x1 + 1)*(b2_y2 - b2_y1 + 1) 76 | 77 | iou = inter_area / (b1_area + b2_area - inter_area) 78 | 79 | return iou 80 | 81 | 82 | def pred_corner_coord(prediction): 83 | #Get indices of non-zero confidence bboxes 84 | ind_nz = torch.nonzero(prediction[:,:,4]).transpose(0,1).contiguous() 85 | 86 | box = prediction[ind_nz[0], ind_nz[1]] 87 | 88 | 89 | box_a = box.new(box.shape) 90 | box_a[:,0] = (box[:,0] - box[:,2]/2) 91 | box_a[:,1] = (box[:,1] - box[:,3]/2) 92 | box_a[:,2] = (box[:,0] + box[:,2]/2) 93 | box_a[:,3] = (box[:,1] + box[:,3]/2) 94 | box[:,:4] = box_a[:,:4] 95 | 96 | prediction[ind_nz[0], ind_nz[1]] = box 97 | 98 | return prediction 99 | 100 | 101 | 102 | 103 | def write(x, batches, results, colors, classes): 104 | c1 = tuple(x[1:3].int()) 105 | c2 = tuple(x[3:5].int()) 106 | img = results[int(x[0])] 107 | cls = int(x[-1]) 108 | label = "{0}".format(classes[cls]) 109 | color = random.choice(colors) 110 | cv2.rectangle(img, c1, c2,color, 1) 111 | t_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 1 , 1)[0] 112 | c2 = c1[0] + t_size[0] + 3, c1[1] + t_size[1] + 4 113 | cv2.rectangle(img, c1, c2,color, -1) 114 | cv2.putText(img, label, (c1[0], c1[1] + t_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 1, [225,255,255], 1); 115 | return img 116 | -------------------------------------------------------------------------------- /car.cfg: -------------------------------------------------------------------------------- 1 | [net] 2 | # Testing 3 | batch=1 4 | subdivisions=1 5 | # Training 6 | #batch=128 7 | #subdivisions=4 8 | width = 768 9 | height = 448 10 | channels=3 11 | momentum=0.9 12 | decay=0.0005 13 | angle=0 14 | saturation = 1.5 15 | exposure = 1.5 16 | hue=.1 17 | 18 | learning_rate=0.0005 19 | burn_in=1000 20 | max_batches = 500200 21 | policy=steps 22 | steps=100, 1000, 5000, 10000, 50000, 60000, 80000, 120000, 140000, 162000 23 | scales=0.1, 1, 5, 2, 1, 0.5, 0.2, 0.1, 0.05, 0.02 24 | 25 | [convolutional] 26 | batch_normalize=1 27 | filters=16 28 | size=3 29 | stride=1 30 | pad=1 31 | activation=leaky 32 | 33 | [maxpool] 34 | size=2 35 | stride=2 36 | 37 | [convolutional] 38 | batch_normalize=1 39 | filters=32 40 | size=3 41 | stride=1 42 | pad=1 43 | activation=leaky 44 | 45 | [maxpool] 46 | size=2 47 | stride=2 48 | 49 | [convolutional] 50 | batch_normalize=1 51 | filters=64 52 | size=3 53 | stride=1 54 | pad=1 55 | activation=leaky 56 | 57 | [maxpool] 58 | size=2 59 | stride=2 60 | 61 | [convolutional] 62 | batch_normalize=1 63 | filters=128 64 | size=3 65 | stride=1 66 | pad=1 67 | activation=leaky 68 | 69 | [maxpool] 70 | size=2 71 | stride=2 72 | 73 | [convolutional] 74 | batch_normalize=1 75 | filters=256 76 | size=3 77 | stride=1 78 | pad=1 79 | share=0 80 | activation=leaky 81 | 82 | [maxpool] 83 | size=2 84 | stride=2 85 | 86 | [convolutional] 87 | batch_normalize=1 88 | filters=512 89 | size=3 90 | stride=1 91 | pad=1 92 | activation=leaky 93 | 94 | [maxpool] 95 | size=2 96 | stride=1 97 | 98 | [convolutional] 99 | batch_normalize=1 100 | filters=1024 101 | size=3 102 | stride=1 103 | pad=1 104 | activation=leaky 105 | 106 | ########### 107 | 108 | [convolutional] 109 | batch_normalize=1 110 | filters=256 111 | size=1 112 | stride=1 113 | share=0 114 | pad=1 115 | activation=leaky 116 | 117 | [convolutional] 118 | batch_normalize=1 119 | filters=512 120 | size=3 121 | stride=1 122 | pad=1 123 | activation=leaky 124 | 125 | 
[convolutional] 126 | size=1 127 | stride=1 128 | pad=1 129 | filters=18 130 | activation=linear 131 | 132 | 133 | 134 | [yolo] 135 | mask = 3,4,5 136 | anchors = 10,14, 23,27, 37,58, 81,82, 135,169, 344,319 137 | classes=1 138 | num=6 139 | jitter=.3 140 | ignore_thresh = .7 141 | truth_thresh = 1 142 | random=1 143 | 144 | [route] 145 | layers = -4 146 | 147 | [convolutional] 148 | batch_normalize=1 149 | filters=128 150 | size=1 151 | stride=1 152 | pad=1 153 | activation=leaky 154 | 155 | [upsample] 156 | stride=2 157 | share=0 158 | 159 | [route] 160 | layers = -1, 8 161 | 162 | [convolutional] 163 | batch_normalize=1 164 | filters=256 165 | size=3 166 | stride=1 167 | pad=1 168 | activation=leaky 169 | 170 | [convolutional] 171 | size=1 172 | stride=1 173 | pad=1 174 | filters=18 175 | activation=linear 176 | 177 | [yolo] 178 | mask = 1,2,3 179 | anchors = 10,14, 23,27, 37,58, 81,82, 135,169, 344,319 180 | classes=1 181 | num=6 182 | jitter=.3 183 | ignore_thresh = .7 184 | truth_thresh = 1 185 | random=1 186 | -------------------------------------------------------------------------------- /car.data: -------------------------------------------------------------------------------- 1 | classes = 1 2 | train = /mnt/diskc/even/Car/train.txt 3 | valid = /mnt/diskc/even/Car/valid.txt 4 | names = cfg/car.names 5 | backup = backup_car 6 | 7 | -------------------------------------------------------------------------------- /car.names: -------------------------------------------------------------------------------- 1 | car 2 | -------------------------------------------------------------------------------- /checkpoints/test: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /darknet.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | import numpy as np 8 | import cv2 9 | import matplotlib.pyplot as plt 10 | from darknet_util import count_parameters as count 11 | from darknet_util import convert2cpu as cpu 12 | from darknet_util import predict_transform 13 | 14 | 15 | class test_net(nn.Module): 16 | def __init__(self, num_layers, input_size): 17 | super(test_net, self).__init__() 18 | self.num_layers = num_layers 19 | self.linear_1 = nn.Linear(input_size, 5) 20 | self.middle = nn.ModuleList([nn.Linear(5, 5) for x in range(num_layers)]) 21 | self.output = nn.Linear(5, 2) 22 | 23 | def forward(self, x): 24 | x = x.view(-1) 25 | fwd = nn.Sequential(self.linear_1, *self.middle, self.output) 26 | return fwd(x) 27 | 28 | 29 | def get_test_input(): 30 | img = cv2.imread("dog-cycle-car.png") 31 | img = cv2.resize(img, (416, 416)) 32 | img_ = img[:, :, ::-1].transpose((2, 0, 1)) 33 | img_ = img_[np.newaxis, :, :, :] / 255.0 34 | img_ = torch.from_numpy(img_).float() 35 | img_ = Variable(img_) 36 | return img_ 37 | 38 | 39 | def parse_cfg(cfgfile): 40 | """ 41 | Takes a configuration file 42 | 43 | Returns a list of blocks. Each blocks describes a block in the neural 44 | network to be built. 
Block is represented as a dictionary in the list 45 | 46 | """ 47 | file = open(cfgfile, 'r') 48 | lines = file.read().split('\n') # store the lines in a list 49 | lines = [x for x in lines if len(x) > 0] # get rid of the empty lines 50 | lines = [x for x in lines if x[0] != '#'] # get rid of commented lines 51 | lines = [x.rstrip().lstrip() for x in lines] 52 | 53 | block = {} # 一个block即一个层 54 | blocks = [] 55 | 56 | for line in lines: 57 | if line[0] == '[': # This marks the start of a new block 58 | if len(block) != 0: 59 | blocks.append(block) # 将已经解析的层放入容器 60 | block = {} 61 | block['type'] = line[1:-1].rstrip() # 层类型 62 | else: 63 | key, value = line.split('=') 64 | block[key.rstrip()] = value.lstrip() 65 | blocks.append(block) 66 | 67 | return blocks 68 | 69 | 70 | # print('\n\n'.join([repr(x) for x in blocks])) 71 | 72 | import pickle as pkl 73 | 74 | 75 | class MaxPoolStride1(nn.Module): 76 | def __init__(self, kernel_size): 77 | super(MaxPoolStride1, self).__init__() 78 | self.kernel_size = kernel_size 79 | self.pad = kernel_size - 1 80 | 81 | def forward(self, x): 82 | padded_x = F.pad(x, (0, self.pad, 0, self.pad), mode="replicate") 83 | pooled_x = nn.MaxPool2d(self.kernel_size, self.pad)(padded_x) 84 | return pooled_x 85 | 86 | 87 | class EmptyLayer(nn.Module): 88 | def __init__(self): 89 | super(EmptyLayer, self).__init__() 90 | 91 | 92 | class DetectionLayer(nn.Module): 93 | def __init__(self, anchors): 94 | super(DetectionLayer, self).__init__() 95 | self.anchors = anchors 96 | 97 | def forward(self, x, inp_dim, num_classes, confidence): 98 | x = x.data 99 | global CUDA 100 | prediction = x 101 | prediction = predict_transform(prediction, inp_dim, self.anchors, num_classes, confidence, CUDA) 102 | return prediction 103 | 104 | 105 | class Upsample(nn.Module): 106 | def __init__(self, stride=2): 107 | super(Upsample, self).__init__() 108 | self.stride = stride 109 | 110 | def forward(self, x): 111 | stride = self.stride 112 | assert (x.data.dim() == 4) 113 | B = x.data.size(0) 114 | C = x.data.size(1) 115 | H = x.data.size(2) 116 | W = x.data.size(3) 117 | ws = stride 118 | hs = stride 119 | x = x.view(B, C, H, 1, W, 1).expand(B, C, H, stride, W, stride).contiguous().view(B, C, H * stride, W * stride) 120 | return x 121 | 122 | 123 | # 124 | 125 | class ReOrgLayer(nn.Module): 126 | def __init__(self, stride=2): 127 | super(ReOrgLayer, self).__init__() 128 | self.stride = stride 129 | 130 | def forward(self, x): 131 | assert (x.data.dim() == 4) 132 | B, C, H, W = x.data.shape 133 | hs = self.stride 134 | ws = self.stride 135 | assert (H % hs == 0), "The stride " + str(self.stride) + " is not a proper divisor of height " + str(H) 136 | assert (W % ws == 0), "The stride " + str(self.stride) + " is not a proper divisor of height " + str(W) 137 | x = x.view(B, C, H // hs, hs, W // ws, ws).transpose(-2, -3).contiguous() 138 | x = x.view(B, C, H // hs * W // ws, hs, ws) 139 | x = x.view(B, C, H // hs * W // ws, hs * ws).transpose(-1, -2).contiguous() 140 | x = x.view(B, C, ws * hs, H // ws, W // ws).transpose(1, 2).contiguous() 141 | x = x.view(B, C * ws * hs, H // ws, W // ws) 142 | return x 143 | 144 | 145 | def create_modules(blocks): 146 | net_info = blocks[0] # Captures the information about the input and pre-processing 147 | 148 | module_list = nn.ModuleList() 149 | 150 | index = 0 # indexing blocks helps with implementing route layers (skip connections) 151 | 152 | prev_filters = 3 # 初始出入3通道图像数据 153 | 154 | output_filters = [] 155 | 156 | for x in blocks: 157 | module = 
158 | 159 | if x['type'] == 'net': 160 | continue 161 | 162 | # A 'convolutional' block bundles the conv layer, batch norm and the non-linear activation 163 | if x['type'] == 'convolutional': 164 | # Get the info about the layer 165 | activation = x['activation'] 166 | try: 167 | batch_normalize = int(x['batch_normalize']) # with batch normalization there is no bias 168 | bias = False 169 | except: 170 | batch_normalize = 0 # without batch normalization there is a bias 171 | bias = True 172 | 173 | filters = int(x['filters']) 174 | padding = int(x['pad']) 175 | kernel_size = int(x['size']) 176 | stride = int(x['stride']) 177 | 178 | if padding: 179 | pad = (kernel_size - 1) // 2 # pad both sides by this amount 180 | else: 181 | pad = 0 182 | 183 | # Add the convolutional layer 184 | conv = nn.Conv2d(prev_filters, filters, kernel_size, stride, pad, bias=bias) 185 | module.add_module('conv_{0}'.format(index), conv) 186 | 187 | # Add the Batch Norm Layer 188 | if batch_normalize: # batch norm belongs to the conv block 189 | bn = nn.BatchNorm2d(filters) 190 | module.add_module('batch_norm_{0}'.format(index), bn) 191 | 192 | # Check the activation. 193 | # It is either Linear or a Leaky ReLU for YOLO 194 | if activation == 'leaky': # the activation belongs to the conv block as well 195 | activn = nn.LeakyReLU(0.1, inplace=True) 196 | module.add_module('leaky_{0}'.format(index), activn) 197 | 198 | # If it's an upsampling layer 199 | # nn.Upsample (nearest) is used instead of the custom Upsample class above 200 | 201 | elif x['type'] == 'upsample': 202 | stride = int(x['stride']) 203 | # upsample = Upsample(stride) 204 | upsample = nn.Upsample(scale_factor=2, mode='nearest') # why use nearest-neighbour interpolation rather than another mode? 205 | module.add_module('upsample_{}'.format(index), upsample) 206 | 207 | # If it is a route layer 208 | elif x['type'] == 'route': 209 | x['layers'] = x['layers'].split(',') 210 | 211 | # Start of a route 212 | start = int(x['layers'][0]) 213 | 214 | # end of the route, if one is given (see the aside just below on the index arithmetic)
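
An aside on the index arithmetic handled in the next few lines: route entries are absolute module indices when positive and relative offsets when negative, and positive entries are first shifted so that everything becomes an offset from the current module. A small worked example with an invented output_filters history (the channel values are illustrative, not read from this network):

output_filters = [16, 32, 64, 128, 256, 512, 1024, 256, 256, 255, 256, 128]
index = 12                   # module currently being built
start, end = -1, 8           # cfg said: layers = -1, 8
if end > 0:
    end = end - index        # 8 -> -4, i.e. four modules back
filters = output_filters[index + start] + output_filters[index + end]
print(filters)               # 128 + 256 = 384 channels after torch.cat along dim 1
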
215 | try: 216 | end = int(x['layers'][1]) 217 | except: 218 | end = 0 219 | 220 | # Positive anotation 221 | if start > 0: 222 | start = start - index 223 | 224 | if end > 0: 225 | end = end - index 226 | 227 | route = EmptyLayer() 228 | module.add_module('route_{0}'.format(index), route) 229 | 230 | if end < 0: 231 | filters = output_filters[index + start] + output_filters[index + end] 232 | else: 233 | filters = output_filters[index + start] 234 | 235 | # shortcut corresponds to skip connection 236 | elif x["type"] == "shortcut": 237 | from_ = int(x["from"]) 238 | shortcut = EmptyLayer() 239 | module.add_module("shortcut_{}".format(index), shortcut) 240 | 241 | elif x["type"] == "maxpool": 242 | stride = int(x["stride"]) 243 | size = int(x["size"]) 244 | if stride != 1: 245 | maxpool = nn.MaxPool2d(size, stride) 246 | else: 247 | maxpool = MaxPoolStride1(size) 248 | 249 | module.add_module("maxpool_{}".format(index), maxpool) 250 | 251 | # Yolo is the detection layer 252 | elif x["type"] == "yolo": 253 | mask = x["mask"].split(",") 254 | mask = [int(x) for x in mask] 255 | 256 | anchors = x["anchors"].split(",") 257 | anchors = [int(a) for a in anchors] 258 | anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)] 259 | anchors = [anchors[i] for i in mask] 260 | 261 | detection = DetectionLayer(anchors) 262 | module.add_module("Detection_{}".format(index), detection) 263 | 264 | else: 265 | print("Something I dunno") 266 | assert False 267 | 268 | module_list.append(module) 269 | prev_filters = filters # 270 | output_filters.append(filters) 271 | index += 1 # 更新index 272 | 273 | return (net_info, module_list) 274 | 275 | 276 | class Darknet(nn.Module): 277 | def __init__(self, cfgfile): 278 | super(Darknet, self).__init__() 279 | self.blocks = parse_cfg(cfgfile) 280 | self.net_info, self.module_list = create_modules(self.blocks) 281 | self.header = torch.IntTensor([0, 0, 0, 0]) 282 | self.seen = 0 283 | 284 | def get_blocks(self): 285 | return self.blocks 286 | 287 | def get_module_list(self): 288 | return self.module_list 289 | 290 | def forward(self, x, CUDA): 291 | detections = [] 292 | modules = self.blocks[1:] 293 | outputs = {} # We cache the outputs for the route layer 294 | 295 | write = 0 296 | for i in range(len(modules)): 297 | 298 | module_type = (modules[i]['type']) 299 | if module_type == 'convolutional' or module_type == 'upsample' or module_type == 'maxpool': 300 | 301 | x = self.module_list[i](x) 302 | outputs[i] = x 303 | 304 | elif module_type == 'route': 305 | layers = modules[i]['layers'] 306 | layers = [int(a) for a in layers] 307 | 308 | if (layers[0]) > 0: 309 | layers[0] = layers[0] - i 310 | 311 | if len(layers) == 1: 312 | x = outputs[i + (layers[0])] 313 | 314 | else: 315 | if (layers[1]) > 0: 316 | layers[1] = layers[1] - i 317 | 318 | map1 = outputs[i + layers[0]] 319 | map2 = outputs[i + layers[1]] 320 | 321 | x = torch.cat((map1, map2), 1) 322 | outputs[i] = x 323 | 324 | elif module_type == "shortcut": 325 | from_ = int(modules[i]["from"]) 326 | x = outputs[i - 1] + outputs[i + from_] 327 | outputs[i] = x 328 | 329 | elif module_type == 'yolo': 330 | anchors = self.module_list[i][0].anchors 331 | # Get the input dimensions 332 | inp_dim = int(self.net_info["height"]) 333 | 334 | # Get the number of classes 335 | num_classes = int(modules[i]["classes"]) 336 | 337 | # Output the result 338 | x = x.data 339 | x = predict_transform(x, inp_dim, anchors, num_classes, CUDA) 340 | 341 | if type(x) == int: 342 | continue 343 | 344 | if not 
write: 345 | detections = x 346 | write = 1 347 | 348 | else: 349 | detections = torch.cat((detections, x), 1) 350 | 351 | outputs[i] = outputs[i - 1] 352 | 353 | try: 354 | return detections 355 | except: 356 | return 0 357 | 358 | def load_weights(self, weight_file): 359 | # Open the weights file 360 | fp = open(weight_file, "rb") 361 | 362 | # The first 4 values are header information 363 | # 1. Major version number 364 | # 2. Minor Version Number 365 | # 3. Subversion number 366 | # 4. IMages seen 367 | header = np.fromfile(fp, dtype=np.int32, count=5) 368 | self.header = torch.from_numpy(header) 369 | self.seen = self.header[3] 370 | 371 | # The rest of the values are the weights 372 | # Let's load them up 373 | weights = np.fromfile(fp, dtype=np.float32) 374 | 375 | ptr = 0 376 | for i in range(len(self.module_list)): 377 | module_type = self.blocks[i + 1]["type"] 378 | 379 | if module_type == "convolutional": 380 | model = self.module_list[i] 381 | try: 382 | batch_normalize = int(self.blocks[i + 1]["batch_normalize"]) 383 | except: 384 | batch_normalize = 0 385 | 386 | conv = model[0] 387 | 388 | if (batch_normalize): 389 | bn = model[1] 390 | 391 | # Get the number of weights of Batch Norm Layer 392 | num_bn_biases = bn.bias.numel() 393 | 394 | # Load the weights 395 | bn_biases = torch.from_numpy(weights[ptr:ptr + num_bn_biases]) 396 | ptr += num_bn_biases 397 | 398 | bn_weights = torch.from_numpy(weights[ptr: ptr + num_bn_biases]) 399 | ptr += num_bn_biases 400 | 401 | bn_running_mean = torch.from_numpy(weights[ptr: ptr + num_bn_biases]) 402 | ptr += num_bn_biases 403 | 404 | bn_running_var = torch.from_numpy(weights[ptr: ptr + num_bn_biases]) 405 | ptr += num_bn_biases 406 | 407 | # Cast the loaded weights into dims of net weights. 408 | bn_biases = bn_biases.view_as(bn.bias.data) 409 | bn_weights = bn_weights.view_as(bn.weight.data) 410 | bn_running_mean = bn_running_mean.view_as(bn.running_mean) 411 | bn_running_var = bn_running_var.view_as(bn.running_var) 412 | 413 | # Copy the data to net 414 | bn.bias.data.copy_(bn_biases) 415 | bn.weight.data.copy_(bn_weights) 416 | bn.running_mean.copy_(bn_running_mean) 417 | bn.running_var.copy_(bn_running_var) 418 | 419 | else: 420 | # Number of biases 421 | num_biases = conv.bias.numel() 422 | 423 | # Load the weights 424 | conv_biases = torch.from_numpy(weights[ptr: ptr + num_biases]) 425 | ptr = ptr + num_biases 426 | 427 | # reshape the loaded weights according to the dims of the net weights 428 | conv_biases = conv_biases.view_as(conv.bias.data) 429 | 430 | # Finally copy the data 431 | conv.bias.data.copy_(conv_biases) 432 | 433 | # Let us load the weights for the Convolutional layers 434 | num_weights = conv.weight.numel() 435 | 436 | # Do the same as above for weights 437 | conv_weights = torch.from_numpy(weights[ptr:ptr + num_weights]) 438 | ptr = ptr + num_weights 439 | 440 | conv_weights = conv_weights.view_as(conv.weight.data) 441 | conv.weight.data.copy_(conv_weights) 442 | print('=> %s loaded.' 
% weight_file) 443 | 444 | def save_weights(self, saved_file, cutoff=0): 445 | 446 | if cutoff <= 0: 447 | cutoff = len(self.blocks) - 1 448 | 449 | fp = open(saved_file, 'wb') 450 | 451 | # Attach the header at the top of the file 452 | self.header[3] = self.seen 453 | header = self.header 454 | 455 | header = header.numpy() 456 | header.tofile(fp) 457 | 458 | # Now, let us save the weights 459 | for i in range(len(self.module_list)): 460 | module_type = self.blocks[i + 1]["type"] 461 | 462 | if (module_type) == "convolutional": 463 | model = self.module_list[i] 464 | try: 465 | batch_normalize = int(self.blocks[i + 1]["batch_normalize"]) 466 | except: 467 | batch_normalize = 0 468 | 469 | conv = model[0] 470 | 471 | if (batch_normalize): 472 | bn = model[1] 473 | 474 | # If the parameters are on GPU, convert them back to CPU 475 | # We don't convert the parameter to GPU 476 | # Instead. we copy the parameter and then convert it to CPU 477 | # This is done as weight are need to be saved during training 478 | cpu(bn.bias.data).numpy().tofile(fp) 479 | cpu(bn.weight.data).numpy().tofile(fp) 480 | cpu(bn.running_mean).numpy().tofile(fp) 481 | cpu(bn.running_var).numpy().tofile(fp) 482 | 483 | 484 | else: 485 | cpu(conv.bias.data).numpy().tofile(fp) 486 | 487 | # Let us save the weights for the Convolutional layers 488 | cpu(conv.weight.data).numpy().tofile(fp) 489 | 490 | # 491 | # dn = Darknet('cfg/yolov3.cfg') 492 | # dn.load_weights("yolov3.weights") 493 | # inp = get_test_input() 494 | # a, interms = dn(inp) 495 | # dn.eval() 496 | # a_i, interms_i = dn(inp) 497 | -------------------------------------------------------------------------------- /darknet_util.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | 3 | from __future__ import division 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.autograd import Variable 9 | import numpy as np 10 | import cv2 11 | import matplotlib.pyplot as plt 12 | from bbox import bbox_iou 13 | 14 | 15 | def count_parameters(model): 16 | return sum(p.numel() for p in model.parameters()) 17 | 18 | 19 | def count_learnable_parameters(model): 20 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 21 | 22 | 23 | def convert2cpu(matrix): 24 | if matrix.is_cuda: 25 | return torch.FloatTensor(matrix.size()).copy_(matrix) 26 | else: 27 | return matrix 28 | 29 | 30 | def predict_transform(prediction, inp_dim, anchors, num_classes, CUDA=True): 31 | batch_size = prediction.size(0) 32 | stride = inp_dim // prediction.size(2) 33 | grid_size = inp_dim // stride 34 | bbox_attrs = 5 + num_classes 35 | num_anchors = len(anchors) 36 | 37 | anchors = [(a[0] / stride, a[1] / stride) for a in anchors] 38 | 39 | prediction = prediction.view( 40 | batch_size, bbox_attrs * num_anchors, grid_size * grid_size) 41 | prediction = prediction.transpose(1, 2).contiguous() 42 | prediction = prediction.view( 43 | batch_size, grid_size * grid_size * num_anchors, bbox_attrs) 44 | 45 | # Sigmoid the centre_X, centre_Y. 
and object confidencce 46 | prediction[:, :, 0] = torch.sigmoid(prediction[:, :, 0]) 47 | prediction[:, :, 1] = torch.sigmoid(prediction[:, :, 1]) 48 | prediction[:, :, 4] = torch.sigmoid(prediction[:, :, 4]) 49 | 50 | # Add the center offsets 51 | grid_len = np.arange(grid_size) 52 | a, b = np.meshgrid(grid_len, grid_len) 53 | 54 | x_offset = torch.FloatTensor(a).view(-1, 1) 55 | y_offset = torch.FloatTensor(b).view(-1, 1) 56 | 57 | if CUDA: 58 | x_offset = x_offset.cuda() 59 | y_offset = y_offset.cuda() 60 | 61 | x_y_offset = torch.cat((x_offset, y_offset), 1).repeat( 62 | 1, num_anchors).view(-1, 2).unsqueeze(0) 63 | 64 | prediction[:, :, :2] += x_y_offset 65 | 66 | # log space transform height and the width 67 | anchors = torch.FloatTensor(anchors) 68 | 69 | if CUDA: 70 | anchors = anchors.cuda() 71 | 72 | anchors = anchors.repeat(grid_size * grid_size, 1).unsqueeze(0) 73 | prediction[:, :, 2:4] = torch.exp(prediction[:, :, 2:4]) * anchors 74 | 75 | # Softmax the class scores 76 | prediction[:, :, 5: 5 + 77 | num_classes] = torch.sigmoid((prediction[:, :, 5: 5 + num_classes])) 78 | 79 | prediction[:, :, :4] *= stride 80 | 81 | return prediction 82 | 83 | 84 | def load_classes(namesfile): 85 | fp = open(namesfile, "r") 86 | names = fp.read().split("\n")[:-1] 87 | return names 88 | 89 | 90 | def get_im_dim(im): 91 | im = cv2.imread(im) 92 | w, h = im.shape[1], im.shape[0] 93 | return w, h 94 | 95 | 96 | def unique(tensor): 97 | tensor_np = tensor.cpu().numpy() 98 | unique_np = np.unique(tensor_np) 99 | unique_tensor = torch.from_numpy(unique_np) 100 | 101 | tensor_res = tensor.new(unique_tensor.shape) 102 | tensor_res.copy_(unique_tensor) 103 | return tensor_res 104 | 105 | 106 | def post_process(prediction, 107 | confidence, 108 | num_classes, 109 | nms=True, 110 | nms_conf=0.4, 111 | CUDA=True): 112 | conf_mask = (prediction[:, :, 4] > confidence).float().unsqueeze(2) 113 | prediction = prediction * conf_mask 114 | 115 | try: 116 | ind_nz = torch.nonzero( 117 | prediction[:, :, 4]).transpose(0, 1).contiguous() 118 | except: 119 | return 0 120 | 121 | box_a = prediction.new(prediction.shape) 122 | box_a[:, :, 0] = (prediction[:, :, 0] - prediction[:, :, 2] / 2) 123 | box_a[:, :, 1] = (prediction[:, :, 1] - prediction[:, :, 3] / 2) 124 | box_a[:, :, 2] = (prediction[:, :, 0] + prediction[:, :, 2] / 2) 125 | box_a[:, :, 3] = (prediction[:, :, 1] + prediction[:, :, 3] / 2) 126 | prediction[:, :, :4] = box_a[:, :, :4] 127 | 128 | batch_size = prediction.size(0) 129 | 130 | output = prediction.new(1, prediction.size(2) + 1) 131 | write = False 132 | 133 | for ind in range(batch_size): 134 | # select the image from the batch 135 | image_pred = prediction[ind] 136 | 137 | # Get the class having maximum score, and the index of that class 138 | # Get rid of num_classes softmax scores 139 | # Add the class index and the class score of class having maximum score 140 | max_conf, max_conf_score = torch.max( 141 | image_pred[:, 5:5 + num_classes], 1) 142 | max_conf = max_conf.float().unsqueeze(1) 143 | max_conf_score = max_conf_score.float().unsqueeze(1) 144 | seq = (image_pred[:, :5], max_conf, max_conf_score) 145 | image_pred = torch.cat(seq, 1) 146 | 147 | # Get rid of the zero entries 148 | non_zero_ind = (torch.nonzero(image_pred[:, 4])) 149 | 150 | image_pred_ = image_pred[non_zero_ind.squeeze(), :].view(-1, 7) 151 | 152 | # Get the various classes detected in the image 153 | try: 154 | img_classes = unique(image_pred_[:, -1]) 155 | except: 156 | continue 157 | # WE will do NMS classwise 158 | 
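
Before the class-wise loop below, here is a self-contained illustration of the greedy suppression it implements: boxes are visited in descending score order, and any later box overlapping a kept box beyond the IoU threshold is dropped. The boxes are hand-made and the iou helper is a local stand-in for bbox_iou, not the repo's implementation:

import torch

def iou_xyxy(a, b):
    # IoU of one (x1, y1, x2, y2) box against a batch of boxes
    x1 = torch.max(a[0], b[:, 0])
    y1 = torch.max(a[1], b[:, 1])
    x2 = torch.min(a[2], b[:, 2])
    y2 = torch.min(a[3], b[:, 3])
    inter = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    return inter / (area_a + area_b - inter)

boxes = torch.tensor([[0., 0., 10., 10., 0.9],
                      [1., 1., 10., 10., 0.8],    # heavy overlap with the first
                      [20., 20., 30., 30., 0.7]])  # disjoint box
keep = [0]                                         # already sorted by score
for i in range(1, len(boxes)):
    if all(iou_xyxy(boxes[i, :4], boxes[[j], :4]).item() < 0.4 for j in keep):
        keep.append(i)
print(keep)  # [0, 2]: the second box is suppressed (IoU 0.81 with the first)
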
for cls in img_classes: 159 | # get the detections with one particular class 160 | cls_mask = image_pred_ * \ 161 | (image_pred_[:, -1] == cls).float().unsqueeze(1) 162 | class_mask_ind = torch.nonzero(cls_mask[:, -2]).squeeze() 163 | 164 | image_pred_class = image_pred_[class_mask_ind].view(-1, 7) 165 | 166 | # sort the detections such that the entry with the maximum objectness 167 | # confidence is at the top 168 | conf_sort_index = torch.sort( 169 | image_pred_class[:, 4], descending=True)[1] 170 | image_pred_class = image_pred_class[conf_sort_index] 171 | idx = image_pred_class.size(0) 172 | 173 | # if nms has to be done 174 | if nms: 175 | # For each detection 176 | for i in range(idx): 177 | # Get the IOUs of all boxes that come after the one we are looking at 178 | # in the loop 179 | try: 180 | ious = bbox_iou(image_pred_class[i].unsqueeze(0), 181 | image_pred_class[i + 1:], 182 | CUDA=CUDA) 183 | except ValueError: 184 | break 185 | 186 | except IndexError: 187 | break 188 | 189 | # Zero out all the detections that have IoU > treshhold 190 | iou_mask = (ious < nms_conf).float().unsqueeze(1) 191 | image_pred_class[i + 1:] *= iou_mask 192 | 193 | # Remove the non-zero entries 194 | non_zero_ind = torch.nonzero( 195 | image_pred_class[:, 4]).squeeze() 196 | image_pred_class = image_pred_class[non_zero_ind].view( 197 | -1, 7) 198 | 199 | # Concatenate the batch_id of the image to the detection 200 | # this helps us identify which image does the detection correspond to 201 | # We use a linear straucture to hold ALL the detections from the batch 202 | # the batch_dim is flattened 203 | # batch is identified by extra batch column 204 | 205 | batch_ind = image_pred_class.new( 206 | image_pred_class.size(0), 1).fill_(ind) 207 | seq = batch_ind, image_pred_class 208 | if not write: 209 | output = torch.cat(seq, 1) 210 | write = True 211 | else: 212 | out = torch.cat(seq, 1) 213 | output = torch.cat((output, out)) 214 | 215 | return output 216 | 217 | 218 | # !/usr/bin/env python3 219 | # -*- coding: utf-8 -*- 220 | """ 221 | Created on Sat Mar 24 00:12:16 2018 222 | 223 | @author: ayooshmac 224 | """ 225 | 226 | 227 | def predict_transform_half(prediction, inp_dim, anchors, num_classes, CUDA=True): 228 | batch_size = prediction.size(0) 229 | stride = inp_dim // prediction.size(2) 230 | 231 | bbox_attrs = 5 + num_classes 232 | num_anchors = len(anchors) 233 | grid_size = inp_dim // stride 234 | 235 | prediction = prediction.view( 236 | batch_size, bbox_attrs * num_anchors, grid_size * grid_size) 237 | prediction = prediction.transpose(1, 2).contiguous() 238 | prediction = prediction.view( 239 | batch_size, grid_size * grid_size * num_anchors, bbox_attrs) 240 | 241 | # Sigmoid the centre_X, centre_Y. 
and object confidencce 242 | prediction[:, :, 0] = torch.sigmoid(prediction[:, :, 0]) 243 | prediction[:, :, 1] = torch.sigmoid(prediction[:, :, 1]) 244 | prediction[:, :, 4] = torch.sigmoid(prediction[:, :, 4]) 245 | 246 | # Add the center offsets 247 | grid_len = np.arange(grid_size) 248 | a, b = np.meshgrid(grid_len, grid_len) 249 | 250 | x_offset = torch.FloatTensor(a).view(-1, 1) 251 | y_offset = torch.FloatTensor(b).view(-1, 1) 252 | 253 | if CUDA: 254 | x_offset = x_offset.cuda().half() 255 | y_offset = y_offset.cuda().half() 256 | 257 | x_y_offset = torch.cat((x_offset, y_offset), 1).repeat( 258 | 1, num_anchors).view(-1, 2).unsqueeze(0) 259 | 260 | prediction[:, :, :2] += x_y_offset 261 | 262 | # log space transform height and the width 263 | anchors = torch.HalfTensor(anchors) 264 | 265 | if CUDA: 266 | anchors = anchors.cuda() 267 | 268 | anchors = anchors.repeat(grid_size * grid_size, 1).unsqueeze(0) 269 | prediction[:, :, 2:4] = torch.exp(prediction[:, :, 2:4]) * anchors 270 | 271 | # Softmax the class scores 272 | prediction[:, :, 5: 5 + num_classes] = nn.Softmax(-1)( 273 | Variable(prediction[:, :, 5: 5 + num_classes])).data 274 | 275 | prediction[:, :, :4] *= stride 276 | 277 | return prediction 278 | 279 | 280 | def write_results_half(prediction, confidence, num_classes, nms=True, nms_conf=0.4): 281 | conf_mask = (prediction[:, :, 4] > confidence).half().unsqueeze(2) 282 | prediction = prediction * conf_mask 283 | 284 | try: 285 | ind_nz = torch.nonzero( 286 | prediction[:, :, 4]).transpose(0, 1).contiguous() 287 | except: 288 | return 0 289 | 290 | box_a = prediction.new(prediction.shape) 291 | box_a[:, :, 0] = (prediction[:, :, 0] - prediction[:, :, 2] / 2) 292 | box_a[:, :, 1] = (prediction[:, :, 1] - prediction[:, :, 3] / 2) 293 | box_a[:, :, 2] = (prediction[:, :, 0] + prediction[:, :, 2] / 2) 294 | box_a[:, :, 3] = (prediction[:, :, 1] + prediction[:, :, 3] / 2) 295 | prediction[:, :, :4] = box_a[:, :, :4] 296 | 297 | batch_size = prediction.size(0) 298 | 299 | output = prediction.new(1, prediction.size(2) + 1) 300 | write = False 301 | 302 | for ind in range(batch_size): 303 | # select the image from the batch 304 | image_pred = prediction[ind] 305 | 306 | # Get the class having maximum score, and the index of that class 307 | # Get rid of num_classes softmax scores 308 | # Add the class index and the class score of class having maximum score 309 | max_conf, max_conf_score = torch.max( 310 | image_pred[:, 5:5 + num_classes], 1) 311 | max_conf = max_conf.half().unsqueeze(1) 312 | max_conf_score = max_conf_score.half().unsqueeze(1) 313 | seq = (image_pred[:, :5], max_conf, max_conf_score) 314 | image_pred = torch.cat(seq, 1) 315 | 316 | # Get rid of the zero entries 317 | non_zero_ind = (torch.nonzero(image_pred[:, 4])) 318 | try: 319 | image_pred_ = image_pred[non_zero_ind.squeeze(), :] 320 | except: 321 | continue 322 | 323 | # Get the various classes detected in the image 324 | img_classes = unique(image_pred_[:, -1].long()).half() 325 | 326 | # WE will do NMS classwise 327 | for cls in img_classes: 328 | # get the detections with one particular class 329 | cls_mask = image_pred_ * \ 330 | (image_pred_[:, -1] == cls).half().unsqueeze(1) 331 | class_mask_ind = torch.nonzero(cls_mask[:, -2]).squeeze() 332 | 333 | image_pred_class = image_pred_[class_mask_ind] 334 | 335 | # sort the detections such that the entry with the maximum objectness 336 | # confidence is at the top 337 | conf_sort_index = torch.sort( 338 | image_pred_class[:, 4], descending=True)[1] 339 | 
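
As an aside here, the arithmetic that predict_transform and predict_transform_half apply to each raw prediction can be checked by hand: sigmoid the centre offsets and add the cell indices, exponentiate the size terms and scale by the anchor, then multiply by the stride. A tiny worked example for one grid cell; the raw numbers are invented, while the anchor 81x82 is one of the stride-32 anchors from car.cfg:

import math

tx, ty, tw, th, obj = 0.2, -0.1, 0.4, 0.3, 1.5   # raw network outputs (invented)
cx, cy = 6, 4                                    # grid cell indices
pw, ph = 81 / 32, 82 / 32                        # anchor scaled to grid units
stride = 32
bx = (1 / (1 + math.exp(-tx)) + cx) * stride     # sigmoid + cell offset
by = (1 / (1 + math.exp(-ty)) + cy) * stride
bw = pw * math.exp(tw) * stride                  # log-space size transform
bh = ph * math.exp(th) * stride
conf = 1 / (1 + math.exp(-obj))                  # objectness
print(round(bx, 1), round(by, 1), round(bw, 1), round(bh, 1), round(conf, 2))
# 209.6 143.2 120.8 110.7 0.82
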
image_pred_class = image_pred_class[conf_sort_index] 340 | idx = image_pred_class.size(0) 341 | 342 | # if nms has to be done 343 | if nms: 344 | # For each detection 345 | for i in range(idx): 346 | # Get the IOUs of all boxes that come after the one we are looking at 347 | # in the loop 348 | try: 349 | ious = bbox_iou(image_pred_class[i].unsqueeze( 350 | 0), image_pred_class[i + 1:]) 351 | except ValueError: 352 | break 353 | 354 | except IndexError: 355 | break 356 | 357 | # Zero out all the detections that have IoU > threshold 358 | iou_mask = (ious < nms_conf).half().unsqueeze(1) 359 | image_pred_class[i + 1:] *= iou_mask 360 | 361 | # Keep only the non-zero entries 362 | non_zero_ind = torch.nonzero( 363 | image_pred_class[:, 4]).squeeze() 364 | image_pred_class = image_pred_class[non_zero_ind] 365 | 366 | # Concatenate the batch_id of the image to the detection 367 | # so we can identify which image the detection corresponds to 368 | # We use a flat structure to hold ALL the detections from the batch: 369 | # the batch_dim is flattened 370 | # and each row is identified by the extra batch column 371 | batch_ind = image_pred_class.new( 372 | image_pred_class.size(0), 1).fill_(ind) 373 | seq = batch_ind, image_pred_class 374 | 375 | if not write: 376 | output = torch.cat(seq, 1) 377 | write = True 378 | else: 379 | out = torch.cat(seq, 1) 380 | output = torch.cat((output, out)) 381 | 382 | return output 383 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import os 4 | import re 5 | import shutil 6 | import pickle 7 | import numpy as np 8 | import scipy.io as scio 9 | from torch import Tensor as Tensor 10 | import torch 11 | from torch.utils import data 12 | from tqdm import tqdm 13 | from torchvision import transforms as T 14 | from PIL import Image 15 | 16 | color_attrs = ['Black', 'Blue', 'Brown', 17 | 'Gray', 'Green', 'Pink', 18 | 'Red', 'White', 'Yellow'] 19 | direction_attrs = ['Front', 'Rear'] 20 | type_attrs = ['passengerCar', 'saloonCar', 21 | 'shopTruck', 'suv', 'trailer', 'truck', 'van', 'waggon'] 22 | 23 | 24 | class Vehicle(data.Dataset): 25 | """ 26 | Multilabel attribute vectors, for use with cross entropy loss 27 | Uses the cleaned data: all 'unknown' samples removed 28 | """ 29 | 30 | def __init__(self, 31 | root, 32 | transform=None, 33 | is_train=True): 34 | """ 35 | :return: 36 | """ 37 | if not os.path.exists(root): 38 | print('=> [Err]: root does not exist.') 39 | return 40 | if is_train: 41 | print('=> train data root: ', root) 42 | else: 43 | print('=> test data root: ', root) 44 | 45 | # collect the non-empty sub-directories and natural-sort them by (class) name 46 | self.img_dirs = [os.path.join(root, x) for x in os.listdir(root) \ 47 | if os.path.isdir(os.path.join(root, x))] 48 | self.img_dirs = [x for x in self.img_dirs if len(os.listdir(x)) != 0] 49 | if len(self.img_dirs) == 0: 50 | print('=> [Err]: empty sub-dirs.') 51 | return 52 | self.img_dirs.sort() # natural sort order, ascending 53 | # print('=> total {:d} classes for training'.format(len(self.img_dirs))) 54 | 55 | # split the multilabel into its parts 56 | self.color_attrs = color_attrs 57 | self.direction_attrs = direction_attrs 58 | self.type_attrs = type_attrs 59 | 60 | # order the file paths by sub-directory (class name) 61 | self.imgs_path = [] 62 | self.labels = [] 63 | for x in self.img_dirs: 64 | match = re.match('([a-zA-Z]+)_([a-zA-Z]+)_([a-zA-Z]+)', os.path.split(x)[1]) 65 | color = match.group(1) # body color 66 | direction = match.group(2) # body orientation 67 | type = match.group(3) # vehicle type 68 | # print('=> color: %s, direction: %s, type: %s' % (color, direction, type))
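
Before the file loop resumes, a sketch of how a folder name such as 'Black_Front_suv' becomes the three label indices built below. It uses the same regex and the module-level attribute lists defined at the top of this file; the folder name itself is an invented example:

import re
import numpy as np

name = 'Black_Front_suv'
color, direction, vtype = re.match(
    r'([a-zA-Z]+)_([a-zA-Z]+)_([a-zA-Z]+)', name).groups()
# each attribute value is looked up by position in its attribute list
color_idx = int(np.where(np.array(color_attrs) == color)[0])
direction_idx = int(np.where(np.array(direction_attrs) == direction)[0])
type_idx = int(np.where(np.array(type_attrs) == vtype)[0])
print(color_idx, direction_idx, type_idx)  # 0 0 3
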
69 | 70 | for y in os.listdir(x): 71 | # record the file path 72 | self.imgs_path.append(os.path.join(x, y)) 73 | 74 | # build the label 75 | color_idx = int(np.where(self.color_attrs == np.array(color))[0]) 76 | direction_idx = int(np.where(self.direction_attrs == np.array(direction))[0]) 77 | type_idx = int(np.where(self.type_attrs == np.array(type))[0]) 78 | label = np.array([color_idx, direction_idx, type_idx], dtype=int) 79 | 80 | label = torch.Tensor(label) # torch.from_numpy(label) 81 | self.labels.append(label) # Tensor(label) 82 | # print(label) 83 | 84 | if is_train: 85 | print('=> total {:d} samples for training.'.format(len(self.imgs_path))) 86 | else: 87 | print('=> total {:d} samples for testing.'.format(len(self.imgs_path))) 88 | 89 | # set up the data transforms 90 | if transform is not None: 91 | self.transform = transform 92 | else: # default image transformation 93 | self.transform = T.Compose([ 94 | T.Resize(448), 95 | T.CenterCrop(448), 96 | T.ToTensor(), 97 | T.Normalize(mean=[0.485, 0.456, 0.406], 98 | std=[0.229, 0.224, 0.225]) 99 | ]) 100 | 101 | # --------------------- serialize imgs_path to disk 102 | # root_parent = os.path.abspath(os.path.join(root, '..')) 103 | # print('=> parent dir: ', root_parent) 104 | # if is_train: 105 | # imgs_path = os.path.join(root_parent, 'train_imgs_path.pkl') 106 | # else: 107 | # imgs_path = os.path.join(root_parent, 'test_imgs_path.pkl') 108 | # print('=> dump imgs path: ', imgs_path) 109 | # pickle.dump(self.imgs_path, open(imgs_path, 'wb')) 110 | 111 | def __getitem__(self, idx): 112 | """ 113 | :param idx: 114 | :return: 115 | """ 116 | image = Image.open(self.imgs_path[idx]) 117 | 118 | # apply the transforms; convert grayscale images to 'RGB' first 119 | if image.mode == 'L' or image.mode == 'I': # 8-bit or 32-bit grayscale 120 | image = image.convert('RGB') 121 | if self.transform is not None: 122 | image = self.transform(image) 123 | label = self.labels[idx] 124 | f_path = os.path.split(self.imgs_path[idx])[0].split('/')[-2] + \ 125 | '/' + os.path.split(self.imgs_path[idx])[0].split('/')[-1] + \ 126 | '/' + os.path.split(self.imgs_path[idx])[1] 127 | return image, label, f_path 128 | 129 | def __len__(self): 130 | """ 131 | :return: 132 | """ 133 | return len(self.imgs_path) 134 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | import numpy as np 8 | import cv2 9 | import matplotlib.pyplot as plt 10 | from darknet_util import count_parameters as count 11 | from darknet_util import convert2cpu as cpu 12 | from PIL import Image, ImageDraw 13 | 14 | 15 | def letterbox_image(img, inp_dim): 16 | ''' 17 | resize image with unchanged aspect ratio using padding 18 | ''' 19 | img_w, img_h = img.shape[1], img.shape[0] 20 | w, h = inp_dim 21 | new_w = int(img_w * min(w / img_w, h / img_h)) 22 | new_h = int(img_h * min(w / img_w, h / img_h)) 23 | resized_image = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_CUBIC) 24 | 25 | canvas = np.full((inp_dim[1], inp_dim[0], 3), 128) 26 | canvas[(h - new_h) // 2:(h - new_h) // 2 + new_h, (w - new_w) // 2:(w - new_w) // 2 + new_w, :] = resized_image 27 | 28 | return canvas 29 | 30 | 31 | def prep_image(img, inp_dim): 32 | """ 33 | Prepare image for inputting to the neural network.
34 | Returns a Tensor or Variable 35 | """ 36 | orig_im = cv2.imread(img) 37 | dim = orig_im.shape[1], orig_im.shape[0] # 图像原始宽高 38 | img = (letterbox_image(orig_im, (inp_dim, inp_dim))) 39 | img_ = img[:, :, ::-1].transpose((2, 0, 1)).copy() # BGR->RGB and WxHxchans => chansxWxH 40 | img_ = torch.from_numpy(img_).float().div(255.0).unsqueeze(0) 41 | return img_, orig_im, dim 42 | 43 | 44 | def process_img(img, inp_dim): 45 | """ 46 | input PIL img, return processed img 47 | """ 48 | dim = img.width, img.height 49 | img = (letterbox_image(np.asarray(img), (inp_dim, inp_dim))) 50 | img_ = img.transpose((2, 0, 1)).copy() # WxHxchans => chansxWxH 51 | img_ = torch.from_numpy(img_).float().div(255.0).unsqueeze(0) 52 | return img_ 53 | 54 | 55 | def prep_image_pil(img, network_dim): 56 | orig_im = Image.open(img) 57 | img = orig_im.convert('RGB') 58 | dim = img.size 59 | img = img.resize(network_dim) 60 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(img.tobytes())) 61 | img = img.view(*network_dim, 3).transpose(0, 1).transpose(0, 2).contiguous() 62 | img = img.view(1, 3, *network_dim) 63 | img = img.float().div(255.0) 64 | return (img, orig_im, dim) 65 | 66 | 67 | def inp_to_image(inp): 68 | inp = inp.cpu().squeeze() 69 | inp = inp * 255 70 | try: 71 | inp = inp.data.numpy() 72 | except RuntimeError: 73 | inp = inp.numpy() 74 | inp = inp.transpose(1, 2, 0) 75 | 76 | inp = inp[:, :, ::-1] 77 | return inp 78 | -------------------------------------------------------------------------------- /test_imgs/test_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_imgs/test_0.jpg -------------------------------------------------------------------------------- /test_imgs/test_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_imgs/test_1.jpg -------------------------------------------------------------------------------- /test_imgs/test_10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_imgs/test_10.jpg -------------------------------------------------------------------------------- /test_imgs/test_11.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_imgs/test_11.jpg -------------------------------------------------------------------------------- /test_imgs/test_12.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_imgs/test_12.jpg -------------------------------------------------------------------------------- /test_imgs/test_13.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_imgs/test_13.jpg 
-------------------------------------------------------------------------------- /test_imgs/test_14.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_imgs/test_14.jpg -------------------------------------------------------------------------------- /test_imgs/test_15.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_imgs/test_15.jpg -------------------------------------------------------------------------------- /test_imgs/test_16.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_imgs/test_16.jpg -------------------------------------------------------------------------------- /test_imgs/test_17.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_imgs/test_17.jpg -------------------------------------------------------------------------------- /test_imgs/test_18.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_imgs/test_18.jpg -------------------------------------------------------------------------------- /test_imgs/test_19.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_imgs/test_19.jpg -------------------------------------------------------------------------------- /test_imgs/test_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_imgs/test_2.jpg -------------------------------------------------------------------------------- /test_imgs/test_20.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_imgs/test_20.jpg -------------------------------------------------------------------------------- /test_imgs/test_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_imgs/test_3.jpg -------------------------------------------------------------------------------- /test_imgs/test_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_imgs/test_4.jpg -------------------------------------------------------------------------------- 
/test_imgs/test_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_imgs/test_5.jpg -------------------------------------------------------------------------------- /test_imgs/test_6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_imgs/test_6.jpg -------------------------------------------------------------------------------- /test_imgs/test_7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_imgs/test_7.jpg -------------------------------------------------------------------------------- /test_imgs/test_8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_imgs/test_8.jpg -------------------------------------------------------------------------------- /test_imgs/test_9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_imgs/test_9.jpg -------------------------------------------------------------------------------- /test_result/test_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_result/test_0.jpg -------------------------------------------------------------------------------- /test_result/test_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_result/test_1.jpg -------------------------------------------------------------------------------- /test_result/test_10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_result/test_10.jpg -------------------------------------------------------------------------------- /test_result/test_11.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_result/test_11.jpg -------------------------------------------------------------------------------- /test_result/test_12.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_result/test_12.jpg -------------------------------------------------------------------------------- /test_result/test_13.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_result/test_13.jpg -------------------------------------------------------------------------------- /test_result/test_14.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_result/test_14.jpg -------------------------------------------------------------------------------- /test_result/test_15.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_result/test_15.jpg -------------------------------------------------------------------------------- /test_result/test_16.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_result/test_16.jpg -------------------------------------------------------------------------------- /test_result/test_17.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_result/test_17.jpg -------------------------------------------------------------------------------- /test_result/test_18.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_result/test_18.jpg -------------------------------------------------------------------------------- /test_result/test_19.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_result/test_19.jpg -------------------------------------------------------------------------------- /test_result/test_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_result/test_2.jpg -------------------------------------------------------------------------------- /test_result/test_20.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_result/test_20.jpg -------------------------------------------------------------------------------- /test_result/test_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_result/test_3.jpg -------------------------------------------------------------------------------- /test_result/test_4.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_result/test_4.jpg -------------------------------------------------------------------------------- /test_result/test_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_result/test_5.jpg -------------------------------------------------------------------------------- /test_result/test_6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_result/test_6.jpg -------------------------------------------------------------------------------- /test_result/test_7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_result/test_7.jpg -------------------------------------------------------------------------------- /test_result/test_8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_result/test_8.jpg -------------------------------------------------------------------------------- /test_result/test_9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaptainEven/Vehicle-Car-detection-and-multilabel-classification/0b0ab3ad8478c5a0ac29819b4fce3ae110d44d82/test_result/test_9.jpg -------------------------------------------------------------------------------- /train_vehicle_multilabel.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import os 4 | import re 5 | import shutil 6 | import time 7 | import pickle 8 | import torch 9 | import torchvision 10 | import numpy as np 11 | import matplotlib.pyplot as plt 12 | import dataset 13 | from dataset import color_attrs, direction_attrs, type_attrs 14 | 15 | from copy import deepcopy 16 | from PIL import Image 17 | 18 | from torchvision.datasets import ImageFolder 19 | 20 | from copy import deepcopy 21 | from torchvision import transforms as T 22 | from PIL import Image 23 | from tqdm import tqdm 24 | 25 | # print('=> torch version: ', torch.__version__) 26 | 27 | is_remote = False 28 | use_cuda = True # True 29 | 30 | os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' 31 | if is_remote: # remote side 32 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' # users can modify this according to needs and hardware 33 | device = torch.device( 34 | 'cuda: 0' if torch.cuda.is_available() and use_cuda else 'cpu') 35 | else: # local side 36 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 37 | device = torch.device( 38 | 'cuda: 0' if torch.cuda.is_available() and use_cuda else 'cpu') 39 | 40 | if use_cuda: 41 | torch.manual_seed(0) 42 | torch.cuda.manual_seed_all(0) 43 | 44 | # print('=> device: ', device) 45 | 46 | 47 | class Classifier(torch.nn.Module): 48 | """ 49 | vehicle multilabel-classifier 50 | """ 51 | 52 | def __init__(self, 
num_cls, input_size, is_freeze=True): 53 | """ 54 | :param is_freeze: 55 | """ 56 | torch.nn.Module.__init__(self) 57 | 58 | # output channels 59 | self._num_cls = num_cls 60 | 61 | # input image size 62 | self.input_size = input_size 63 | 64 | self._is_freeze = is_freeze 65 | print('=> is freeze: {}'.format(self._is_freeze)) 66 | 67 | # delete origin FC and add custom FC 68 | self.features = torchvision.models.resnet18(pretrained=True) # True 69 | del self.features.fc 70 | # print('feature extractor:\n', self.features) 71 | 72 | self.features = torch.nn.Sequential( 73 | *list(self.features.children())) 74 | 75 | self.fc = torch.nn.Linear(512 ** 2, num_cls) # output channels 76 | # print('=> fc layer:\n', self.fc) 77 | 78 | # -----------whether to freeze 79 | if self._is_freeze: 80 | for param in self.features.parameters(): 81 | param.requires_grad = False 82 | 83 | # init FC layer 84 | torch.nn.init.kaiming_normal_(self.fc.weight.data) 85 | if self.fc.bias is not None: 86 | torch.nn.init.constant_(self.fc.bias.data, val=0) 87 | 88 | def forward(self, X): 89 | """ 90 | :param X: 91 | :return: 92 | """ 93 | N = X.size()[0] 94 | 95 | # assert X.size() == (N, 3, self.input_size, self.input_size) 96 | 97 | X = self.features(X) # extract features 98 | 99 | # print('X.size: ', X.size()) 100 | # assert X.size() == (N, 512, 1, 1) 101 | 102 | X = X.view(N, 512, 1 ** 2) 103 | X = torch.bmm(X, torch.transpose(X, 1, 2)) / (1 ** 2) # Bi-linear CNN for fine-grained classification 104 | 105 | # assert X.size() == (N, 512, 512) 106 | 107 | X = X.view(N, 512 ** 2) 108 | X = torch.sqrt(X + 1e-5) 109 | X = torch.nn.functional.normalize(X) 110 | X = self.fc(X) 111 | 112 | assert X.size() == (N, self._num_cls) 113 | return X 114 | 115 | 116 | class Manager(object): 117 | """ 118 | train and test manager 119 | """ 120 | def __init__(self, options, path): 121 | """ 122 | model initialization 123 | """ 124 | self.options = options 125 | self.path = path 126 | 127 | # get latest model checkpoint 128 | if self.options['is_resume']: 129 | if int(self.path['model_id']) == -1: 130 | checkpoints = os.listdir(self.path['net']) 131 | checkpoints.sort(key=lambda x: int(re.match('epoch_(\d+)\.pth', x).group(1)), 132 | reverse=True) 133 | if len(checkpoints) != 0: 134 | self.LATEST_MODEL_ID = int( 135 | re.match('epoch_(\d+)\.pth', checkpoints[0]).group(1)) 136 | else: 137 | self.LATEST_MODEL_ID = int(self.path['model_id']) 138 | else: 139 | self.LATEST_MODEL_ID = 0 140 | print('=> latest net id: {}'.format(self.LATEST_MODEL_ID)) 141 | 142 | # net config 143 | if is_remote: 144 | self.net = Classifier(num_cls=19, # 19 = len(color_attrs) + len(direction_attrs) + len(type_attrs) 145 | input_size=224, 146 | is_freeze=self.options['is_freeze']).to(device) 147 | else: 148 | self.net = Classifier(num_cls=19, 149 | input_size=224, 150 | is_freeze=self.options['is_freeze']).to(device) 151 | 152 | # whether to resume from checkpoint 153 | if self.options['is_resume']: 154 | if int(self.path['model_id']) == -1: 155 | model_path = os.path.join(self.path['net'], checkpoints[0]) 156 | else: 157 | model_path = self.path['net'] + '/' + \ 158 | 'epoch_' + self.path['model_id'] + '.pth' 159 | self.net.load_state_dict(torch.load(model_path)) 160 | print('=> net resume from {}'.format(model_path)) 161 | else: 162 | print('=> net loaded from scratch.') 163 | 164 | # loss function 165 | self.loss_func = torch.nn.CrossEntropyLoss().to(device) 166 | 167 | # Solver 168 | if self.options['is_freeze']: 169 | print('=> fine-tune only the FC layer.') 
170 | self.solver = torch.optim.SGD(self.net.fc.parameters(), 171 | lr=self.options['base_lr'], 172 | momentum=0.9, 173 | weight_decay=self.options['weight_decay']) 174 | else: 175 | print('=> fine-tune all layers.') 176 | self.solver = torch.optim.SGD(self.net.parameters(), 177 | lr=self.options['base_lr'], 178 | momentum=0.9, 179 | weight_decay=self.options['weight_decay']) 180 | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.solver, 181 | mode='max', 182 | factor=0.1, 183 | patience=3, 184 | verbose=True, 185 | threshold=1e-4) 186 | 187 | # train data enhancement 188 | self.train_transforms = torchvision.transforms.Compose([ 189 | torchvision.transforms.Resize( 190 | size=self.net.input_size), # Let smaller edge match 191 | torchvision.transforms.RandomHorizontalFlip(), 192 | torchvision.transforms.RandomCrop( 193 | size=self.net.input_size), 194 | torchvision.transforms.ToTensor(), 195 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), 196 | std=(0.229, 0.224, 0.225)) 197 | ]) 198 | 199 | # test preprocess 200 | self.test_transforms = torchvision.transforms.Compose([ 201 | torchvision.transforms.Resize(size=self.net.input_size), 202 | torchvision.transforms.CenterCrop(size=self.net.input_size), 203 | torchvision.transforms.ToTensor(), 204 | torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), 205 | std=(0.229, 0.224, 0.225)) 206 | ]) 207 | 208 | # load train and test data 209 | if is_remote: 210 | self.train_set = dataset.Vehicle(self.path['train_data'], 211 | transform=self.test_transforms, # train_transforms 212 | is_train=True) 213 | self.test_set = dataset.Vehicle(self.path['test_data'], 214 | transform=self.test_transforms, 215 | is_train=False) 216 | else: 217 | self.train_set = dataset.Vehicle(self.path['train_data'], 218 | transform=self.test_transforms, # train_transforms 219 | is_train=True) 220 | self.test_set = dataset.Vehicle(self.path['test_data'], 221 | transform=self.test_transforms, 222 | is_train=False) 223 | self.train_loader = torch.utils.data.DataLoader(self.train_set, 224 | batch_size=self.options['batch_size'], 225 | shuffle=True, 226 | num_workers=4, 227 | pin_memory=True) 228 | self.test_loader = torch.utils.data.DataLoader(self.test_set, 229 | batch_size=1, # one image each batch for testing 230 | shuffle=False, 231 | num_workers=4, 232 | pin_memory=True) 233 | 234 | # multilabels 235 | self.color_attrs = color_attrs 236 | print('=> color attributes:\n', self.color_attrs) 237 | 238 | self.direction_attrs = direction_attrs 239 | print('=> direction attributes:\n', self.direction_attrs) 240 | 241 | self.type_attrs = type_attrs 242 | print('=> type_attributes:\n', self.type_attrs, '\n') 243 | 244 | # for storage and further analysis for err details 245 | self.err_dict = {} 246 | 247 | def train(self): 248 | """ 249 | train the network 250 | """ 251 | print('==> Training...') 252 | 253 | self.net.train() # train mode 254 | 255 | best_acc = 0.0 256 | best_epoch = None 257 | 258 | print('=> Epoch\tTrain loss\tTrain acc\tTest acc') 259 | for t in range(self.options['epochs']): # traverse each epoch 260 | epoch_loss = [] 261 | num_correct = 0 262 | num_total = 0 263 | 264 | for data, label, _ in self.train_loader: # traverse each batch in the epoch 265 | # put training data, label to device 266 | data, label = data.to(device), label.to(device) 267 | 268 | # clear the grad 269 | self.solver.zero_grad() 270 | 271 | # forword calculation 272 | output = self.net.forward(data) 273 | 274 | # calculate each attribute loss 275 | label = 
label.long() 276 | loss_color = self.loss_func(output[:, :9], label[:, 0]) 277 | loss_direction = self.loss_func(output[:, 9:11], label[:, 1]) 278 | loss_type = self.loss_func(output[:, 11:], label[:, 2]) 279 | loss = loss_color + loss_direction + 2.0 * loss_type # greater weight to type 280 | 281 | # statistics of each epoch loss 282 | epoch_loss.append(loss.item()) 283 | 284 | # statistics of sample number 285 | num_total += label.size(0) 286 | 287 | # statistics of accuracy 288 | pred = self.get_predict(output) 289 | label = label.cpu().long() 290 | num_correct += self.count_correct(pred, label) 291 | 292 | # backward calculation according to loss 293 | loss.backward() 294 | self.solver.step() 295 | 296 | # calculate training accuray 297 | train_acc = 100.0 * float(num_correct) / float(num_total) 298 | 299 | # calculate accuracy of test set 300 | test_acc = self.test_accuracy(self.test_loader, is_draw=False) 301 | 302 | # schedule the learning rate according to test acc 303 | self.scheduler.step(test_acc) 304 | 305 | if test_acc > best_acc: 306 | best_acc = test_acc 307 | best_epoch = t + 1 308 | 309 | # dump model to disk 310 | model_save_name = 'epoch_' + \ 311 | str(t + self.LATEST_MODEL_ID + 1) + '.pth' 312 | torch.save(self.net.state_dict(), 313 | os.path.join(self.path['net'], model_save_name)) 314 | print('<= {} saved.'.format(model_save_name)) 315 | print('\t%d \t%4.3f \t\t%4.2f%% \t\t%4.2f%%' % 316 | (t + 1, sum(epoch_loss) / len(epoch_loss), train_acc, test_acc)) 317 | 318 | # statistics of details of each epoch 319 | err_dict_path = './err_dict.pkl' 320 | pickle.dump(self.err_dict, open(err_dict_path, 'wb')) 321 | print('=> err_dict dumped @ %s' % err_dict_path) 322 | self.err_dict = {} # reset err dict 323 | 324 | print('=> Best at epoch %d, test accuaray %f' % (best_epoch, best_acc)) 325 | 326 | def test_accuracy(self, data_loader, is_draw=False): 327 | """ 328 | multi-label test acc 329 | """ 330 | self.net.eval() # test mode 331 | 332 | num_correct = 0 333 | num_total = 0 334 | 335 | # counters 336 | num_color = 0 337 | num_direction = 0 338 | num_type = 0 339 | total_time = 0.0 340 | 341 | print('=> testing...') 342 | for data, label, f_name in data_loader: 343 | # place data in device 344 | if is_draw: 345 | img = data.cpu()[0] 346 | img = self.ivt_tensor_img(img) # Tensor -> image 347 | data, label = data.to(device), label.to(device) 348 | 349 | # format label 350 | label = label.cpu().long() 351 | 352 | start = time.time() 353 | 354 | # forward calculation and processing output 355 | output = self.net.forward(data) 356 | pred = self.get_predict(output) # return to cpu 357 | 358 | # time consuming 359 | end = time.time() 360 | total_time += float(end - start) 361 | if is_draw: 362 | print('=> classifying time: {:2.3f} ms'.format( 363 | 1000.0 * (end - start))) 364 | 365 | # count total number 366 | num_total += label.size(0) 367 | 368 | # count each attribute acc 369 | color_name = self.color_attrs[pred[0][0]] 370 | direction_name = self.direction_attrs[pred[0][1]] 371 | type_name = self.type_attrs[pred[0][2]] 372 | 373 | if is_draw: 374 | fig = plt.figure(figsize=(6, 6)) 375 | plt.imshow(img) 376 | plt.title(color_name + ' ' + direction_name + ' ' + type_name) 377 | plt.show() 378 | 379 | # num_correct += self.count_correct(pred, label) 380 | num_correct += self.statistics_result(pred, label, f_name) 381 | 382 | # calculate acc of each attribute 383 | num_color += self.count_attrib_correct(pred, label, 0) 384 | num_direction += self.count_attrib_correct(pred, label, 
1) 385 | num_type += self.count_attrib_correct(pred, label, 2) 386 | 387 | # calculate time consuming of inference 388 | print('=> average inference time: {:2.3f} ms'.format( 389 | 1000.0 * total_time / float(len(data_loader)))) 390 | 391 | accuracy = 100.0 * float(num_correct) / float(num_total) 392 | color_acc = 100.0 * float(num_color) / float(num_total) 393 | direction_acc = 100.0 * float(num_direction) / float(num_total) 394 | type_acc = 100.0 * float(num_type) / float(num_total) 395 | 396 | print( 397 | '=> test accuracy: {:.3f}% | color acc: {:.3f}%, direction acc: {:.3f}%, type acc: {:.3f}%'.format( 398 | accuracy, color_acc, direction_acc, type_acc)) 399 | return accuracy 400 | 401 | def get_predict(self, output): 402 | """ 403 | processing output 404 | :param output: 405 | :return: prediction 406 | """ 407 | # get prediction for each label 408 | output = output.cpu() # get data back to cpu side 409 | pred_color = output[:, :9] 410 | pred_direction = output[:, 9:11] 411 | pred_type = output[:, 11:] 412 | 413 | color_idx = pred_color.max(1, keepdim=True)[1] 414 | direction_idx = pred_direction.max(1, keepdim=True)[1] 415 | type_idx = pred_type.max(1, keepdim=True)[1] 416 | pred = torch.cat((color_idx, direction_idx, type_idx), dim=1) 417 | return pred 418 | 419 | def count_correct(self, pred, label): 420 | """ 421 | :param pred: 422 | :param label: 423 | :return: 424 | """ 425 | # label_cpu = label.cpu().long() # 需要将label转化成long tensor 426 | assert pred.size(0) == label.size(0) 427 | correct_num = 0 428 | for one, two in zip(pred, label): 429 | if torch.equal(one, two): 430 | correct_num += 1 431 | return correct_num 432 | 433 | def statistics_result(self, pred, label, f_name): 434 | """ 435 | statistics of correct and error 436 | :param pred: 437 | :param label: 438 | :param f_name: 439 | :return: 440 | """ 441 | # label_cpu = label.cpu().long() 442 | assert pred.size(0) == label.size(0) 443 | correct_num = 0 444 | for name, one, two in zip(f_name, pred, label): 445 | if torch.equal(one, two): # statistics of correct number 446 | correct_num += 1 447 | else: # statistics of detailed error info 448 | pred_color = self.color_attrs[one[0]] 449 | pred_direction = self.direction_attrs[one[1]] 450 | pred_type = self.type_attrs[one[2]] 451 | 452 | label_color = self.color_attrs[two[0]] 453 | label_direction = self.direction_attrs[two[1]] 454 | label_type = self.type_attrs[two[2]] 455 | err_result = label_color + ' ' + label_direction + ' ' + label_type + \ 456 | ' => ' + \ 457 | pred_color + ' ' + pred_direction + ' ' + pred_type 458 | self.err_dict[name] = err_result 459 | return correct_num 460 | 461 | def count_attrib_correct(self, pred, label, idx): 462 | """ 463 | :param pred: 464 | :param label: 465 | :param idx: 466 | :return: 467 | """ 468 | assert pred.size(0) == label.size(0) 469 | correct_num = 0 470 | for one, two in zip(pred, label): 471 | if one[idx] == two[idx]: 472 | correct_num += 1 473 | return correct_num 474 | 475 | def ivt_tensor_img(self, inp, title=None): 476 | """ 477 | Imshow for Tensor. 
    def ivt_tensor_img(self, inp, title=None):
        """
        invert the test-time normalization so a Tensor can be shown as an image
        """

        # convert C x H x W tensor layout to H x W x C
        inp = inp.numpy().transpose((1, 2, 0))

        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])

        # undo the normalization
        inp = std * inp + mean

        # clipping
        inp = np.clip(inp, 0, 1)

        # plt.imshow(inp)
        # if title is not None:
        #     plt.title(title)
        # plt.pause(0.001)  # pause a bit so that plots are updated
        return inp

    def recognize_pil(self, image):
        """
        classify a single image
        :param image: PIL Image
        :return: (color, direction, type) attribute names
        """
        img = deepcopy(image)
        if img.mode == 'L' or img.mode == 'I':  # convert 8-bit or 32-bit grayscale to RGB
            img = img.convert('RGB')
        img = self.test_transforms(img)
        img = img.view(1, 3, self.net.module.input_size,
                       self.net.module.input_size)

        # put data on the device
        img = img.to(device)

        start = time.time()

        # inference calculation
        output = self.net.forward(img)

        # get prediction
        pred = self.get_predict(output)

        end = time.time()

        print('=> classifying time: {:2.3f} ms'.format(1000.0 * (end - start)))

        color_name = self.color_attrs[pred[0][0]]
        direction_name = self.direction_attrs[pred[0][1]]
        type_name = self.type_attrs[pred[0][2]]

        # fig = plt.figure(figsize=(6, 6))
        # plt.imshow(image)
        # plt.title(color_name + ' ' + direction_name + ' ' + type_name)
        # plt.show()
        return color_name, direction_name, type_name

    def test_single(self):
        """
        classify every image in a fixed test directory
        """
        self.net.eval()

        root = '/mnt/diskc/even/Car_DR/test_set'
        for file in os.listdir(root):
            file_path = os.path.join(root, file)
            image = Image.open(file_path)
            print('=> %s: %s %s %s' % ((file,) + self.recognize_pil(image)))

    def random_pick(self, src, dst, pick_num=20):
        """
        randomly pick images from src to dst
        :param src: source directory, searched recursively for '.jpg' files
        :param dst: destination directory, cleared before picking
        :param pick_num: number of images to pick
        """
        if not os.path.exists(src) or not os.path.exists(dst):
            print('=> [Err]: invalid dir.')
            return

        if len(os.listdir(dst)) != 0:
            shutil.rmtree(dst)
            os.mkdir(dst)

        # recursively traverse src, collecting '.jpg' file paths
        jpgs_path = []

        def find_jpgs(root, jpgs_path):
            """
            :param root: directory to search
            :param jpgs_path: output list of jpg file paths
            """
            for file in os.listdir(root):
                file_path = os.path.join(root, file)

                if os.path.isdir(file_path):  # if dir, recurse into it
                    find_jpgs(file_path, jpgs_path)
                else:  # if file, put it into the list
                    if os.path.isfile(file_path) and file_path.endswith('.jpg'):
                        jpgs_path.append(file_path)

        find_jpgs(src, jpgs_path)
        # print('=> all jpgs path:\n', jpgs_path)

        # random sample without replacement
        pick_ids = np.random.choice(
            len(jpgs_path), size=pick_num, replace=False)
        for idx in pick_ids:
            shutil.copy(jpgs_path[idx], dst)
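# Example use of Manager.random_pick (editor's sketch; the paths are
# hypothetical and must point at existing directories):
#
#   manager = Manager(options, path)
#   manager.random_pick(src='f:/vehicle_test', dst='f:/test_set', pick_num=20)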

def run():
    """
    entry point: parse arguments, build a Manager and train
    """
    import argparse

    def str2bool(s):
        # argparse's type=bool treats every non-empty string as True,
        # so parse boolean flag values explicitly
        return str(s).lower() in ('true', '1', 'yes')

    parser = argparse.ArgumentParser(
        description='Train bi-linear CNN based vehicle multilabel classification.')
    parser.add_argument('--base_lr',
                        dest='base_lr',
                        type=float,
                        default=1.0,
                        help='Base learning rate for training.')
    parser.add_argument('--batch_size',
                        dest='batch_size',
                        type=int,
                        default=64,  # 64
                        help='Batch size.')  # can be set larger with multiple GPUs
    parser.add_argument('--epochs',
                        dest='epochs',
                        type=int,
                        default=100,
                        help='Epochs for training.')
    parser.add_argument('--weight_decay',
                        dest='weight_decay',
                        type=float,
                        default=1e-8,
                        help='Weight decay.')
    # parser.add_argument('--use-cuda', type=str2bool, default=True,
    #                     help='whether to use GPU or not.')
    parser.add_argument('--is-freeze',
                        type=str2bool,
                        default=True,
                        help='whether to freeze all layers except the FC layer.')
    parser.add_argument('--is-resume',
                        type=str2bool,
                        default=False,
                        help='whether to resume from checkpoints.')
    parser.add_argument('--pre-train',
                        type=str2bool,
                        default=True,
                        help='whether in pre-training mode.')
    args = parser.parse_args()

    if args.base_lr <= 0:
        raise AttributeError('--base_lr parameter must be > 0.')
    if args.batch_size <= 0:
        raise AttributeError('--batch_size parameter must be > 0.')
    if args.epochs <= 0:
        raise AttributeError('--epochs parameter must be > 0.')
    if args.weight_decay <= 0:
        raise AttributeError('--weight_decay parameter must be > 0.')

    if args.pre_train:
        options = {
            'base_lr': args.base_lr,
            'batch_size': args.batch_size,
            'epochs': args.epochs,
            'weight_decay': args.weight_decay,
            'is_freeze': True,
            'is_resume': False
        }
    else:
        options = {
            'base_lr': args.base_lr,
            'batch_size': args.batch_size,
            'epochs': args.epochs,
            'weight_decay': args.weight_decay,
            'is_freeze': False,
            'is_resume': True
        }

    # hyperparameters for fine-tuning
    if not options['is_freeze']:
        options['base_lr'] = 1e-3
        options['epochs'] = 100
        options['weight_decay'] = 1e-8  # 1e-8
    print('=> options:\n', options)

    parent_dir = os.path.realpath(
        os.path.join(os.getcwd(), '..')) + os.path.sep
    project_root = parent_dir
    print('=> project_root: ', project_root)

    if is_remote:  # paths on the remote server
        path = {
            'net': '/mnt/diskc/even/b_cnn/filter_test_model',
            'model_id': '-1',  # -1
            'train_data': '/mnt/diskc/even/vehicle_train',
            'test_data': '/mnt/diskc/even/vehicle_test'
        }
    else:  # local paths
        path = {
            'net': './checkpoints',
            'model_id': '-1',
            'train_data': 'f:/vehicle_train',
            'test_data': 'f:/vehicle_test'
        }

    manager = Manager(options, path)
    manager.train()
    # manager.test_accuracy(manager.test_loader, is_draw=True)
    # manager.random_pick(src='/mnt/diskc/even/Car_DR/vehicle_test', dst='/mnt/diskc/even/Car_DR/test_set')
    # manager.test_single()


if __name__ == '__main__':
    run()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
# coding: utf-8

import os
import pickle
import shutil

from PIL import Image
from matplotlib import pyplot as plt


def merge_violet2blue(train_root):
    """
    to simplify data labeling and training, merge 'Violet' sub-dirs
    into the corresponding 'Blue' ones
    """
    sub_dirs = [train_root + '/' + x for x in os.listdir(train_root)]
    violets = [x for x in sub_dirs if os.path.isdir(x) and 'Violet' in x]
    print('=> violets:\n', violets)

    # merge violet files into the corresponding blue dir
    for x in violets:
        dst_dir = x.replace('Violet', 'Blue')

        # create the blue dir first so no violet file is lost
        if not os.path.exists(dst_dir):
            os.makedirs(dst_dir)

        for y in os.listdir(x):
            src_path = x + '/' + y
            if os.path.exists(src_path):
                dst_path = dst_dir + '/' + y
                if not os.path.exists(dst_path):
                    shutil.copy(src_path, dst_dir)

    # delete the now-merged violet sub-dirs
    for x in violets:
        if os.path.exists(x):
            shutil.rmtree(x)
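# Example (editor's sketch; mirrors the call commented out in __main__ below):
#
#   merge_violet2blue('f:/vehicle_test')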

def viz_err(err_path, root='f:/'):
    """
    visualize the detailed error info dumped to err_dict.pkl during training
    """
    err_dict = pickle.load(open(err_path, 'rb'))
    # print(err_dict)

    fig = plt.figure()

    for k, v in err_dict.items():
        img_path = root + k
        if os.path.isfile(img_path):
            img = Image.open(img_path)
            plt.gcf().set_size_inches(8, 8)
            plt.imshow(img)
            plt.title(img_path + '\n' + v)
            plt.gca().set_xticks([])
            plt.gca().set_yticks([])
            plt.show()


if __name__ == '__main__':
    # merge_violet2blue('f:/vehicle_test')
    pass
--------------------------------------------------------------------------------