├── README.md
├── heatmap
│   ├── README.md
│   ├── data
│   │   └── 08_11
│   │       └── test
│   │           ├── imgs
│   │           │   └── 202108110759_6.jpg
│   │           └── labels
│   │               └── 202108110759_6.json
│   ├── main
│   │   ├── config_loc.py
│   │   ├── data_pre_loc.py
│   │   ├── models_loc.py
│   │   ├── net_util_loc.py
│   │   ├── test_main.py
│   │   └── train_main.py
│   ├── requirements.txt
│   └── result
│       └── 08_11
│           └── test
│               └── 202108110759_6_keypoint.jpg
└── offset
    ├── README.md
    ├── data
    │   └── 08_11
    │       └── test
    │           ├── imgs
    │           │   └── 202107170976_4.jpg
    │           └── labels
    │               └── 202107170976_4.json
    ├── main
    │   ├── config.py
    │   ├── data_pre.py
    │   ├── determine_rotation_angle.py
    │   ├── models.py
    │   ├── net_util.py
    │   ├── test_main.py
    │   └── train_main.py
    └── result
        └── 08_11
            └── test_data
                └── 202107170976_4_keypoint.jpg

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Code Overview

1. The offset folder implements the approach used by early researchers: the network regresses keypoint coordinates directly.
   - The drawback is that prediction is very hard; apart from the Euclidean distance between coordinates, the network gets no other supervisory signal.

2. The heatmap folder implements the approach more common today: each keypoint is converted into a heatmap target.
   - This yields noticeably better results; sample outputs on the industrial inspection task are shown in the README inside the heatmap folder.

--------------------------------------------------------------------------------
/heatmap/README.md:
--------------------------------------------------------------------------------

## KeyPoint-Detection/heatmap
Gaussian heatmaps are used as the network output; keypoint accuracy improves further, with errors kept within a few pixels.
The data folder provides one example image and its annotated JSON file.

## Main points
* Gaussian heatmaps are used as labels; the network outputs heatmaps of the same spatial dimensions
* The network model is a U-Net
* The loss function is torch.nn.MSELoss()
* The maximum of each predicted heatmap is taken as the predicted keypoint position (a minimal sketch of this round trip follows)
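The label/decode round trip can be summarized in a few lines. A minimal sketch, not the project code (see `data_pre_loc.py` below for the real implementation); sizes are the ones from `config_loc.py`:

```python
import cv2
import numpy as np

h, w = 752, 992          # network input size
x, y = 400.0, 300.0      # a keypoint in input coordinates

# label: a one-hot map blurred with a Gaussian kernel, peak rescaled to 255
heatmap = np.zeros((h, w), dtype=np.float64)
heatmap[int(y), int(x)] = 1.0
heatmap = cv2.GaussianBlur(heatmap, (51, 51), 0)
heatmap *= 255.0 / heatmap.max()

# decode: the argmax of the (predicted) heatmap gives the keypoint back
row, col = np.unravel_index(np.argmax(heatmap), heatmap.shape)
print(col, row)  # (400, 300)
```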
## Results
Green dots are the annotated keypoints; red dots are the predicted keypoints.
Model results<br>

--------------------------------------------------------------------------------
/heatmap/data/08_11/test/imgs/202108110759_6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExileSaber/KeyPoint-Detection/3fbe892aaed83b305080364d3da1664abb64c53b/heatmap/data/08_11/test/imgs/202108110759_6.jpg
--------------------------------------------------------------------------------
/heatmap/main/config_loc.py:
--------------------------------------------------------------------------------

import os
import torch

config_loc = {
    # training
    'device': torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    'batch_size': 1,
    'epochs': 510,
    'save_epoch': 100,
    'learning_rate': 0.0001,
    'lr_scheduler': 'step1',  # 'step1' or 'step2' for step decay, 'exponential' for exponential decay

    # original image size
    'img_h': 3036,
    'img_w': 4024,

    # size after cropping
    'cut_h': 3008,
    'cut_w': 3968,

    # network input size
    'input_h': 752,
    'input_w': 992,

    # Gaussian kernel size
    'gauss_h': 51,
    'gauss_w': 51,

    # number of keypoints
    'kpt_n': 3,

    # evaluation
    'test_batch_size': 1,
    'test_threshold': 0.5,

    'path': '/home/lwm/Disk_D/Projects/IndustrialProjects/Haier',

    # dataset paths
    'train_date': 'all_9_20220219',
    'train_way': 'train',
    'test_date': 'all_9_20220219',
    'test_way': 'train',

    # weights used for testing
    # 'pkl_file': '20211123_2.pth',
    'pkl_file': 'all_9_0216.pth',

    # whether to load a pretrained model
    'use_old_pkl': True,
    'old_pkl': 'min_loss.pth',

    # # pytorch < 1.6
    # 'pytorch_version': False,

    # remembered location
    'start_x': 200,
    'start_y': 200,
    'start_angle': 0,

    # max x, y
    'max_x': 300,
    'max_y': 250,
    'max_angle': 90,

    # min x, y
    'min_x': 100,
    'min_y': 100,

    # relative distances between keypoints
    'distance_12': 360,
    'distance_13': 200,
    'distance_23': 410,

    'delta': 50,

    # 'photo_to_world':
}
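

# Editor's sketch (assumption): the sizes above are tied together by a fixed
# 0.25 resize factor -- the crop (cut_w x cut_h) is exactly 4x the network
# input (input_w x input_h). A quick sanity check:
if __name__ == '__main__':
    assert config_loc['cut_w'] * 0.25 == config_loc['input_w']  # 3968 / 4 = 992
    assert config_loc['cut_h'] * 0.25 == config_loc['input_h']  # 3008 / 4 = 752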
--------------------------------------------------------------------------------
/heatmap/main/data_pre_loc.py:
--------------------------------------------------------------------------------

import os
import json
import numpy as np
import math
import matplotlib.pyplot as plt
from config_loc import config_loc as cfg
from scipy.ndimage import gaussian_filter
import cv2
import PIL
from torchvision import transforms
import torch
from PIL import Image, ImageFont, ImageDraw
# os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


# read a labelme JSON file and return its keypoints as a numpy array
def json_to_numpy(dataset_path):
    with open(dataset_path) as fp:
        json_data = json.load(fp)
    points = json_data['shapes']

    landmarks = []
    for point in points:
        for p in point['points']:
            landmarks.append(p)

    landmarks = np.array(landmarks)
    landmarks = landmarks.reshape(-1, 2)

    # optionally save as .npy
    # np.save(os.path.join(save_path, name.split('.')[0] + '.npy'), landmarks)

    return landmarks


# build one Gaussian heatmap per keypoint; note that `sigma` here is the
# (h, w) kernel size passed to cv2.GaussianBlur, not a standard deviation
def generate_heatmaps(landmarks, height, width, sigma):
    cut_h = cfg['cut_h']
    cut_w = cfg['cut_w']

    heatmaps = []
    for points in landmarks:
        heatmap = np.zeros((height, width))

        # map the keypoint from crop coordinates to heatmap coordinates
        ch = int(height * points[1] / cut_h)
        cw = int(width * points[0] / cut_w)
        heatmap[ch][cw] = 1

        heatmap = cv2.GaussianBlur(heatmap, sigma, 0)
        am = np.amax(heatmap)
        heatmap /= am / 255  # rescale so the peak value is 255
        heatmaps.append(heatmap)

    heatmaps = np.array(heatmaps)
    # heatmaps = np.expand_dims(heatmaps, axis=0)

    return heatmaps


def show_heatmap(heatmaps):
    for heatmap in heatmaps:
        plt.imshow(heatmap, cmap='hot', interpolation='nearest')
        plt.show()


# take the argmax of each heatmap and map it back to crop coordinates
def heatmap_to_point(heatmaps):
    cut_h = cfg['cut_h']
    cut_w = cfg['cut_w']
    input_h = cfg['input_h']
    input_w = cfg['input_w']

    points = []
    for heatmap in heatmaps:
        pos = np.unravel_index(np.argmax(heatmap), heatmap.shape)  # (row, col)
        x = cut_w * (pos[1] / input_w)
        y = cut_h * (pos[0] / input_h)
        points.append([x, y])
    return np.array(points)


def show_inputImg_and_keypointLabel(imgPath, heatmaps):
    points = []
    for heatmap in heatmaps:
        pos = np.unravel_index(np.argmax(heatmap), heatmap.shape)
        points.append([pos[1], pos[0]])

    img = PIL.Image.open(imgPath).convert('RGB')
    img = transforms.ToTensor()(img)  # 3*3036*4024
    img = img[:, :cfg['cut_h'], :cfg['cut_w']]

    img = img.unsqueeze(0)  # add a batch dimension for Upsample
    resize = torch.nn.Upsample(scale_factor=(0.25, 0.25), mode='bilinear', align_corners=True)
    img = resize(img)

    img = img.squeeze(0)  # drop the batch dimension

    print(img.shape)

    img = transforms.ToPILImage()(img)
    draw = ImageDraw.Draw(img)
    for point in points:
        print(point)
        draw.point((point[0], point[1]), fill='yellow')

    # save
    os.makedirs(os.path.join('..', 'show'), exist_ok=True)
    img.save(os.path.join('..', 'show', 'out.jpg'))


if __name__ == '__main__':
    landmarks = json_to_numpy('../data/0828_back/labels/202108280005_2.json')
    print('keypoint coordinates', landmarks, '-------------', sep='\n')

    heatmaps = generate_heatmaps(landmarks, cfg['input_h'], cfg['input_w'], (cfg['gauss_h'], cfg['gauss_w']))
    # print(heatmaps)
    print(heatmaps.shape)

    # show heatmap picture
    # show_heatmap(heatmaps)

    # show cut image and the keypoints
    # show_inputImg_and_keypointLabel('../data/08_11_in/train/imgs/202108110556_6.jpg', heatmaps)
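
# Editor's sketch (assumption): round-trip check -- a point encoded by
# generate_heatmaps should come back, up to the ~4 px quantization of the
# 0.25 resize, from heatmap_to_point:
#
#   pts = np.array([[1000.0, 800.0]])
#   hm = generate_heatmaps(pts, cfg['input_h'], cfg['input_w'], (cfg['gauss_h'], cfg['gauss_w']))
#   print(heatmap_to_point(hm))  # ~ [[1000., 800.]]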
--------------------------------------------------------------------------------
/heatmap/main/models_loc.py:
--------------------------------------------------------------------------------

from torchsummaryX import summary
from net_util_loc import *
from torchviz import make_dot
import tensorwatch as tw
from tensorboardX import SummaryWriter


# U-Net downsampling block: two convolutions
class DoubleConv(nn.Module):

    def __init__(self, in_channels, out_channels, channel_reduce=False):
        super(DoubleConv, self).__init__()

        # channel-reduction coefficient
        coefficient = 2 if channel_reduce else 1

        self.down = nn.Sequential(
            nn.Conv2d(in_channels, coefficient * out_channels, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(coefficient * out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(coefficient * out_channels, out_channels, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.down(x)


# upsampling block (transposed convolution plus skip connection)
class Up(nn.Module):

    # note: in_channels is the channel count fed into DoubleConv after the
    # skip concatenation; out_channels is the channel count after it
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # upsample the feature map first
        self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=4, stride=2, padding=1)
        self.conv = DoubleConv(in_channels, out_channels, channel_reduce=True)

    def forward(self, x1, x2):
        x1 = self.up(x1)

        # print(x1.shape, x2.shape)
        x = torch.cat([x1, x2], dim=1)
        x = self.conv(x)
        return x


# simple U-Net model
class U_net(nn.Module):

    def __init__(self):
        super(U_net, self).__init__()

        # downsampling path
        self.double_conv1 = DoubleConv(3, 32)
        self.double_conv2 = DoubleConv(32, 64)
        self.double_conv3 = DoubleConv(64, 128)
        self.double_conv4 = DoubleConv(128, 256)
        self.double_conv5 = DoubleConv(256, 256)

        # upsampling path
        self.up1 = Up(512, 128)
        self.up2 = Up(256, 64)
        self.up3 = Up(128, 32)
        self.up4 = Up(64, 16)

        # final layer
        self.out = nn.Conv2d(16, cfg['kpt_n'], kernel_size=(1, 1), padding=0)

    def forward(self, x):
        # down (shape comments assume the 992x752 input from config_loc)
        c1 = self.double_conv1(x)   # (,32,752,992)
        p1 = nn.MaxPool2d(2)(c1)    # (,32,376,496)
        c2 = self.double_conv2(p1)  # (,64,376,496)
        p2 = nn.MaxPool2d(2)(c2)    # (,64,188,248)
        c3 = self.double_conv3(p2)  # (,128,188,248)
        p3 = nn.MaxPool2d(2)(c3)    # (,128,94,124)
        c4 = self.double_conv4(p3)  # (,256,94,124)
        p4 = nn.MaxPool2d(2)(c4)    # (,256,47,62)
        c5 = self.double_conv5(p4)  # (,256,47,62)
        # no pooling after the last convolution

        # up
        u1 = self.up1(c5, c4)  # (,128,94,124)
        u2 = self.up2(u1, c3)  # (,64,188,248)
        u3 = self.up3(u2, c2)  # (,32,376,496)
        u4 = self.up4(u3, c1)  # (,16,752,992)

        # final layer: map to kpt_n heatmaps
        out = self.out(u4)

        return out

    def summary(self, net):
        x = torch.rand(cfg['batch_size'], 3, cfg['input_h'], cfg['input_w'])  # 752*992
        # move to device
        x = x.to(cfg['device'])
        # output shape
        # print(net(x).shape)

        # show the network structure
        summary(net, x)


# quick debugging entry point
if __name__ == "__main__":
    m = U_net().to(cfg['device'])

    m.summary(m)
    img = tw.draw_model(m, [cfg['batch_size'], 3, cfg['input_h'], cfg['input_w']])
    # # print(img)
    img.save('/home/mlg1504/bolt_project/algorithm/keypoint_det/code2/U-net.png')

    x = torch.rand(cfg['batch_size'], 3, cfg['input_h'], cfg['input_w'])
    x = x.to(cfg['device'])
    # model = U_net()
    # with SummaryWriter(comment='U-Net') as w:
    #     w.add_graph(m, x)

    # y = m(x)
    #
    # g = make_dot(y)
    #
    # g.render('espnet_model', view=False)
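
# Editor's note (assumption): the Up(in_channels, out_channels) arithmetic is
# easy to misread -- in_channels counts the channels *after* the skip
# concatenation. E.g. up1 = Up(512, 128): c5 (256 ch) is upsampled to 256 ch,
# concatenated with c4 (256 ch) to give 512 ch, then DoubleConv maps 512 -> 128.
# A quick shape check:
#
#   up = Up(512, 128)
#   x1 = torch.rand(1, 256, 47, 62)   # like c5
#   x2 = torch.rand(1, 256, 94, 124)  # like c4
#   print(up(x1, x2).shape)           # torch.Size([1, 128, 94, 124])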
--------------------------------------------------------------------------------
/heatmap/main/net_util_loc.py:
--------------------------------------------------------------------------------

import torch
import os
import numpy as np
from torch import nn
import torchvision
from config_loc import config_loc as cfg
import torch.utils.data
from torchvision import datasets, transforms, models
import cv2
import PIL
from PIL import Image, ImageFont, ImageDraw
from data_pre_loc import json_to_numpy, generate_heatmaps


# dataset repository (class name inherited from an earlier box_3D project)
class Dataset(torch.utils.data.Dataset):
    def __init__(self, dataset_path):
        self.dataset_path = dataset_path
        self.img_name_list = os.listdir(os.path.join(dataset_path, cfg['train_way'], 'imgs'))

    # return the image and label at the given index
    def __getitem__(self, index):
        # image first
        img_name = self.img_name_list[index]
        img = PIL.Image.open(os.path.join(self.dataset_path, cfg['train_way'], 'imgs', img_name)).convert('RGB')
        img = transforms.ToTensor()(img)  # 3*3036*4024
        img = img[:, :cfg['cut_h'], :cfg['cut_w']]

        img = img.unsqueeze(0)  # add a batch dimension for Upsample
        resize = torch.nn.Upsample(scale_factor=(0.25, 0.25), mode='bilinear', align_corners=True)
        img = resize(img).squeeze(0)
        # print(img.shape)

        # then the label
        mask_name = img_name.split('.')[0] + '.json'
        mask = json_to_numpy(os.path.join(self.dataset_path, cfg['train_way'], 'labels', mask_name))
        # mask = np.load(os.path.join(self.dataset_path, 'labels', self.img_name_list[index].split('.')[0] + '.json'), allow_pickle=True)
        # mask = torch.tensor(mask, dtype=torch.float32)

        heatmaps = generate_heatmaps(mask, cfg['input_h'], cfg['input_w'], (cfg['gauss_h'], cfg['gauss_w']))
        heatmaps = torch.tensor(heatmaps, dtype=torch.float32)

        return img, heatmaps, img_name

    # dataset size
    def __len__(self):
        return len(self.img_name_list)
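
# Editor's sketch (assumption): typical usage of this Dataset with a
# DataLoader, mirroring train_main.py below; requires the data folder to exist.
if __name__ == '__main__':
    ds = Dataset(os.path.join('..', 'data', cfg['train_date']))
    loader = torch.utils.data.DataLoader(ds, batch_size=cfg['batch_size'], shuffle=True)
    img, heatmaps, name = next(iter(loader))
    print(img.shape, heatmaps.shape, name)  # (B,3,752,992), (B,kpt_n,752,992)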
--------------------------------------------------------------------------------
/heatmap/main/test_main.py:
--------------------------------------------------------------------------------

from net_util_loc import *
import os
from models_loc import *
import cv2
import PIL
import xml.etree.cElementTree as ET
from xml.etree import ElementTree
import numpy as np
from xml.dom import minidom
import time
# from determine_rotation_angle_loc import calculate_rotation_angle
# from determine_location_loc import determine_location  # module not included in this repo
from data_pre_loc import json_to_numpy, generate_heatmaps, heatmap_to_point


def show_point_on_picture(img, landmarks, landmarks_gt):
    for point in landmarks:
        point = tuple([int(point[0]), int(point[1])])
        img = cv2.circle(img, center=point, radius=20, color=(0, 0, 255), thickness=-1)
    for point in landmarks_gt:
        point = tuple([int(point[0]), int(point[1])])
        img = cv2.circle(img, center=point, radius=20, color=(0, 255, 0), thickness=-1)
    return img


# prediction
def evaluate(flag=False):
    date = cfg['test_date']
    way = cfg['test_way']

    # test images
    img_path = os.path.join('..', 'data', date, way, 'imgs')
    # test labels
    label_path = os.path.join('..', 'data', date, way, 'labels')

    # define the model
    model = U_net()

    model.load_state_dict(torch.load(os.path.join('..', 'weights', cfg['pkl_file'])))
    model.to(cfg['device'])
    # model.summary(model)
    model.eval()

    total_loss = 0
    diff_angle_list = []
    diff_delta_x_list = []
    diff_delta_y_list = []
    max_keypoint_diff = 0

    # run prediction
    for index, name in enumerate(os.listdir(img_path)):
        print('image name:', name + '  image index: ' + str(index + 1))

        # img = cv2.imread(os.path.join(img_path, name))
        # img = cv2.resize(img, (cfg['input_w'], cfg['input_h']))
        # img = transforms.ToTensor()(img)
        # img = torch.unsqueeze(img, dim=0)  # DataLoader adds the batch dimension during training

        img = PIL.Image.open(os.path.join(img_path, name)).convert('RGB')
        img = transforms.ToTensor()(img)  # 3*3036*4024
        img = img[:, :cfg['cut_h'], :cfg['cut_w']]

        img = img.unsqueeze(0)  # add a batch dimension
        resize = torch.nn.Upsample(scale_factor=(0.25, 0.25), mode='bilinear', align_corners=True)
        img = resize(img)

        print('input tensor shape:', img.shape)

        # feed the network
        img = img.to(cfg['device'])

        pre = model(img)
        pre = pre.cpu().detach().numpy()
        # pre = pre.reshape(pre.shape[0], -1, 2)

        pre_point = heatmap_to_point(pre[0])

        point = json_to_numpy(os.path.join(label_path, name.split('.')[0] + '.json'))

        print('image file:', os.path.join(img_path, name))
        print('label file:', os.path.join(label_path, name.split('.')[0] + '.json'))

        print('predicted keypoints:\n', pre_point)
        print('ground-truth keypoints:\n', point)

        pre_label = torch.Tensor(pre_point.reshape(1, -1)).to(cfg['device'])
        label = torch.Tensor(point.reshape(1, -1)).to(cfg['device'])

        loss_F = torch.nn.MSELoss()
        loss_F.to(cfg['device'])
        loss = loss_F(pre_label, label)  # MSE over flattened coordinates

        print('+++ coordinate MSE:', loss.item())
        total_loss += loss.item()

        if loss.item() > max_keypoint_diff:
            max_keypoint_diff = loss.item()

        print('---------')

        if flag:
            del img

            img = cv2.imread(os.path.join(img_path, name))
            img = show_point_on_picture(img, pre_point, point)

            # save the rendered image
            save_dir = os.path.join('..', 'result', date, way + '_data')
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)

            result_dir = os.path.join(save_dir, name.split('.')[0] + '_keypoint.jpg')
            print('annotated image saved to:', result_dir)
            cv2.imwrite(result_dir, img)

    print('##################')
    print('# ---- Mean ---- #')
    print('##################')

    print('mean coordinate MSE per image:', total_loss / (index + 1), '  max coordinate MSE:', max_keypoint_diff)


if __name__ == "__main__":
    # choose weights
    # choose_weights()

    # evaluate one set of weights
    evaluate(flag=True)
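
# Editor's sketch (assumption): MSE over flattened coordinates is hard to read
# as "pixels of error". If a per-keypoint pixel distance is wanted instead,
# something like this could replace the MSE report inside the loop:
#
#   dists = np.linalg.norm(pre_point - point, axis=1)  # Euclidean error per keypoint
#   print('mean pixel error:', dists.mean(), '  max pixel error:', dists.max())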
--------------------------------------------------------------------------------
/heatmap/main/train_main.py:
--------------------------------------------------------------------------------

from models_loc import *
from config_loc import config_loc as cfg


# training entry point
def train():
    print('start')

    save_epoch = cfg['save_epoch']
    min_avg_loss = 5000
    date = cfg['train_date']

    # model
    model = U_net()

    # weights
    start_epoch = 0

    # load an initial model
    if cfg['use_old_pkl'] is True:
        model.load_state_dict(torch.load(os.path.join('..', 'weights', cfg['old_pkl'])))
        print('model loaded')

    model.to(cfg['device'])
    model.summary(model)

    start_epoch += 1

    # data repository
    dataset = Dataset(os.path.join('..', 'data', cfg['train_date']))

    train_data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                                    batch_size=cfg['batch_size'],
                                                    shuffle=True)

    # loss and optimizer
    loss_F = torch.nn.MSELoss()
    loss_F.to(cfg['device'])
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg['learning_rate'])

    # TODO: add learning-rate decay (the 'lr_scheduler' config key is not wired up yet)

    max_loss_name = ''
    for epoch in range(start_epoch, cfg['epochs'], 1):
        # model.train()
        total_loss = 0.0
        min_loss = 1000
        max_loss = 0.001
        # iterate over batches
        for index, (x, y, img_name) in enumerate(train_data_loader):
            img = x.to(cfg['device'])
            label = y.to(cfg['device'])

            pre = model(img)  # forward pass

            loss = loss_F(pre, label)  # compute the loss
            optimizer.zero_grad()      # gradients must be cleared before each backward pass
            loss.backward()            # populate gradients
            optimizer.step()           # update parameters
            total_loss += loss.item()

            if loss < min_loss:
                min_loss = loss

            if loss > max_loss:
                max_loss = loss
                max_loss_name = img_name[0]

            # if (index+1) % 5 == 0:
            #     print('Epoch %d loss %f' % (epoch, total_loss / (index + 1)))

        avg_loss = total_loss / (index + 1)

        if avg_loss < min_avg_loss:
            min_avg_loss = avg_loss
            torch.save(model.state_dict(), os.path.join('..', "weights", 'min_loss.pth'))
            # if cfg['pytorch_version'] is False:
            #     torch.save(model.state_dict(), os.path.join('..', "weights", 'old_version_min_loss.pth'), _use_new_zipfile_serialization=cfg['pytorch_version'])

        print('Epoch %d, photo number %d, avg loss %f, min loss %f, max loss %f, max loss name %s, min avg loss %f' % (epoch, index + 1, avg_loss, min_loss, max_loss, max_loss_name, min_avg_loss))

        print('-------------------')

        # save weights every save_epoch epochs
        # save_name = "epoch_" + str(epoch).zfill(3) + ".pth"
        # if epoch % save_epoch == 0:
        #     torch.save(model.state_dict(), os.path.join('..', "weights", save_name))


if __name__ == "__main__":
    # train
    train()

--------------------------------------------------------------------------------
/heatmap/requirements.txt:
--------------------------------------------------------------------------------

matplotlib==3.2.2
numpy==1.21.6
opencv_python==4.5.4.60
Pillow==9.2.0
scipy==1.7.3
tensorboardX==2.5.1
tensorwatch==0.9.1
torch==1.8.1
torchsummaryX==1.3.0
torchvision==0.9.1  # matches torch 1.8.1 (0.11.x pairs with torch 1.10.x)
torchviz==0.0.2

--------------------------------------------------------------------------------
/heatmap/result/08_11/test/202108110759_6_keypoint.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExileSaber/KeyPoint-Detection/3fbe892aaed83b305080364d3da1664abb64c53b/heatmap/result/08_11/test/202108110759_6_keypoint.jpg
--------------------------------------------------------------------------------
/offset/README.md:
--------------------------------------------------------------------------------

# Keypoint-Detection/offset
Keypoint detection on self-annotated industrial images; each image is annotated with 4 keypoints.

## Main points
* This part was a first attempt at a detection-style task, written for practice and to understand the network
* The network is the downsampling half of a U-Net
* Labels use the Coordinate method; the loss is simply the sum of squared distances between coordinate pairs (see the sketch below)
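The coordinate method in a nutshell, as a minimal sketch with assumed shapes (see `models.py` and `train_main.py` below for the real code):

```python
import torch

pre = torch.rand(1, 8)    # network output: 4 keypoints flattened as (x1, y1, ..., x4, y4)
label = torch.rand(1, 8)  # annotated coordinates, flattened the same way

# the squared coordinate error is the only training signal
loss = torch.nn.MSELoss()(pre, label)
```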
## Results
Green dots are the annotated keypoints; red dots are the predicted keypoints.
Model results<br>

--------------------------------------------------------------------------------
/offset/data/08_11/test/imgs/202107170976_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExileSaber/KeyPoint-Detection/3fbe892aaed83b305080364d3da1664abb64c53b/offset/data/08_11/test/imgs/202107170976_4.jpg
--------------------------------------------------------------------------------
/offset/main/config.py:
--------------------------------------------------------------------------------

import os
import torch

config = {
    # training
    'device': torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    'batch_size': 1,
    'epochs': 510,
    'save_epoch': 100,

    # dataset paths
    'train_date': 'all_out',
    'train_way': 'train',
    'test_date': 'all_in',
    'test_way': 'test',

    # weights used for testing
    'pkl_file': 'min_loss_all_in.pth'

}

--------------------------------------------------------------------------------
/offset/main/data_pre.py:
--------------------------------------------------------------------------------

import os
import json
import numpy as np
import matplotlib.pyplot as plt
from config import config as cfg
import cv2


# read every labelme JSON under a dataset folder and return a dict mapping
# image name -> flat keypoint array (x1, y1, ..., x4, y4)
def json_to_numpy(dataset_path):
    imgs_path = os.path.join(dataset_path, 'imgs')
    labels_path = os.path.join(dataset_path, 'labels')

    all_landmarks = {}
    for name in os.listdir(imgs_path):
        # read the label
        with open(os.path.join(labels_path, name.split('.')[0] + '.json'), 'r', encoding='utf8') as fp:
            json_data = json.load(fp)
        points = json_data['shapes']

        landmarks = []
        for point in points:
            for p in point['points'][0]:
                landmarks.append(p)

        all_landmarks[name] = np.array(landmarks)

        # optionally save as .npy
        # np.save(os.path.join(save_path, name.split('.')[0] + '.npy'), landmarks)

    return all_landmarks


# read a single labelme JSON and return the flat keypoint array
def one_json_to_numpy(dataset_path):
    with open(dataset_path) as fp:
        json_data = json.load(fp)
    points = json_data['shapes']

    landmarks = []
    for point in points:
        for p in point['points'][0]:
            landmarks.append(p)

    landmarks = np.array(landmarks)

    return landmarks


if __name__ == '__main__':
    # example path; adjust to an actual dataset folder
    print(json_to_numpy(os.path.join('..', 'data', cfg['train_date'], 'train')))
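
# Editor's note (assumption): the labelme JSON consumed above is expected to
# look roughly like this -- each shape holds one keypoint, and
# point['points'][0] is its (x, y) pair:
#
#   {"imagePath": "202107170976_4.jpg",
#    "shapes": [{"label": "p1", "points": [[1203.0, 864.0]]},
#               {"label": "p2", "points": [[1876.0, 910.0]]},
#               ...]}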
--------------------------------------------------------------------------------
/offset/main/determine_rotation_angle.py:
--------------------------------------------------------------------------------

import os
import json
import math
import numpy as np
from numpy.linalg import det


# find the center and radius of the circle through three points
def points2circle(p1, p2, p3):
    p1 = np.array(p1)
    p2 = np.array(p2)
    p3 = np.array(p3)
    num1 = len(p1)
    num2 = len(p2)
    num3 = len(p3)

    # input check
    if (num1 == num2) and (num2 == num3):
        if num1 == 2:
            p1 = np.append(p1, 0)
            p2 = np.append(p2, 0)
            p3 = np.append(p3, 0)
        elif num1 != 3:
            print('\tonly 2D or 3D coordinates are supported')
            return None
    else:
        print('\tinput coordinates have inconsistent dimensions')
        return None

    # collinearity check
    temp01 = p1 - p2
    temp02 = p3 - p2
    temp03 = np.cross(temp01, temp02)
    temp = (temp03 @ temp03) / (temp01 @ temp01) / (temp02 @ temp02)
    if temp < 10**-6:
        print('\tthe three points are collinear; no unique circle exists')
        return None

    temp1 = np.vstack((p1, p2, p3))
    temp2 = np.ones(3).reshape(3, 1)
    mat1 = np.hstack((temp1, temp2))  # size = 3x4

    m = +det(mat1[:, 1:])
    n = -det(np.delete(mat1, 1, axis=1))
    p = +det(np.delete(mat1, 2, axis=1))
    q = -det(temp1)
    temp3 = np.array([p1 @ p1, p2 @ p2, p3 @ p3]).reshape(3, 1)
    temp4 = np.hstack((temp3, mat1))
    temp5 = np.array([2 * q, -m, -n, -p, 0])
    mat2 = np.vstack((temp4, temp5))  # size = 4x5

    A = +det(mat2[:, 1:])
    B = -det(np.delete(mat2, 1, axis=1))
    C = +det(np.delete(mat2, 2, axis=1))
    D = -det(np.delete(mat2, 3, axis=1))
    E = +det(mat2[:, :-1])
    pc = -np.array([B, C, D]) / 2 / A
    r = np.sqrt(B * B + C * C + D * D - 4 * A * E) / 2 / abs(A)
    return pc, r


def calculate_distance(point_1, point_2):
    dis_x = abs(point_1[0] - point_2[0])
    dis_y = abs(point_1[1] - point_2[1])
    distance = (dis_x ** 2 + dis_y ** 2) ** 0.5
    return distance


def determine_point_in_circle(p, pc, r):
    distance = calculate_distance(p, pc)
    if distance < r:
        return True
    else:
        return False


def find_left_and_right_point(point_1, point_2):
    if point_1[0] < point_2[0]:
        return point_1, point_2
    else:
        return point_2, point_1


# return the index of the point that lies inside the circle of the other three
def find_center_point(keypoints):
    assert len(keypoints) == 4
    pc_0, r_0 = points2circle(keypoints[1], keypoints[2], keypoints[3])
    pc_1, r_1 = points2circle(keypoints[0], keypoints[2], keypoints[3])
    pc_2, r_2 = points2circle(keypoints[0], keypoints[1], keypoints[3])
    pc_3, r_3 = points2circle(keypoints[0], keypoints[1], keypoints[2])
    if determine_point_in_circle(keypoints[0], pc_0, r_0):
        return 0
    elif determine_point_in_circle(keypoints[1], pc_1, r_1):
        return 1
    elif determine_point_in_circle(keypoints[2], pc_2, r_2):
        return 2
    elif determine_point_in_circle(keypoints[3], pc_3, r_3):
        return 3


def get_two_point_center_point(p_1, p_2):
    p = []
    p.append((p_1[0] + p_2[0]) / 2)
    p.append((p_1[1] + p_2[1]) / 2)
    return p


def find_two_point_by_distance(c_point, center_id, keypoints):
    keypoints.pop(center_id)
    assert len(keypoints) == 3
    id_list = [[0, 1], [0, 2], [1, 2]]
    center_0_1 = get_two_point_center_point(keypoints[0], keypoints[1])
    center_0_2 = get_two_point_center_point(keypoints[0], keypoints[2])
    center_1_2 = get_two_point_center_point(keypoints[1], keypoints[2])
    distance_c01_cp = calculate_distance(center_0_1, c_point)
    distance_c02_cp = calculate_distance(center_0_2, c_point)
    distance_c12_cp = calculate_distance(center_1_2, c_point)
    distance_list = [distance_c01_cp, distance_c02_cp, distance_c12_cp]
    sorted_id = sorted(range(len(distance_list)), key=lambda k: distance_list[k], reverse=True)
    near_id = id_list[sorted_id[-1]]
    return keypoints[near_id[0]], keypoints[near_id[1]]


def find_two_points_by_angle(c_point, center_id, keypoints):
    keypoints.pop(center_id)
    assert len(keypoints) == 3
    id_list = [[0, 1], [0, 2], [1, 2]]
    angle_0_1 = angle(c_point + keypoints[0], c_point + keypoints[1])
    angle_0_2 = angle(c_point + keypoints[0], c_point + keypoints[2])
    angle_1_2 = angle(c_point + keypoints[1], c_point + keypoints[2])
    angle_list = [angle_0_1, angle_0_2, angle_1_2]
    sorted_id = sorted(range(len(angle_list)), key=lambda k: angle_list[k], reverse=True)
    near_id = id_list[sorted_id[0]]
    return keypoints[near_id[0]], keypoints[near_id[1]]


def determine_direction(point_center, point):
    center_x = point_center[0]
    center_y = point_center[1]
    x = point[0]
    y = point[1]
    if center_x < x and center_y > y:
        return 'northeast'
    elif center_x < x and center_y < y:
        return 'southeast'
    elif center_x > x and center_y < y:
        return 'southwest'
    elif center_x > x and center_y > y:
        return 'northwest'


# included angle (degrees) between two segments, each given as [x1, y1, x2, y2]
def angle(v1, v2):
    dx1 = v1[2] - v1[0]
    dy1 = v1[3] - v1[1]
    dx2 = v2[2] - v2[0]
    dy2 = v2[3] - v2[1]
    angle1 = math.atan2(dy1, dx1)
    angle1 = float(angle1 * 180 / math.pi)
    # print(angle1)
    angle2 = math.atan2(dy2, dx2)
    angle2 = float(angle2 * 180 / math.pi)
    # print(angle2)
    if angle1 * angle2 >= 0:
        included_angle = abs(angle1 - angle2)
    else:
        included_angle = abs(angle1) + abs(angle2)
        if included_angle > 180:
            included_angle = 360 - included_angle
    return included_angle


def calculate_rotation_angle(keypoints):
    # find which of the four points lies in the middle of the other three
    center_id = find_center_point(keypoints)
    point_center = keypoints[center_id]

    # find the two points on the rim of the drum
    # the_one, the_other = find_two_point_by_distance(point_center, center_id, keypoints)
    the_one, the_other = find_two_points_by_angle(point_center, center_id, keypoints)
    point_left, point_right = find_left_and_right_point(the_one, the_other)

    # determine which way the drum currently faces
    point_mid = []
    point_mid.append((point_left[0] + point_right[0]) / 2)
    point_mid.append((point_left[1] + point_right[1]) / 2)
    direction = determine_direction(point_center, point_mid)
    if direction == 'northeast' or direction == 'northwest':
        turn = 'clockwise'
    elif direction == 'southeast' or direction == 'southwest':
        turn = 'anticlockwise'
    # print(point_left, point_right, sep='  ')

    # compute the included angle; v1 is the y-axis vector
    v1 = [0.0, 0.0, 0.0, -1.0]
    # v2 = point_left + point_right
    if turn == 'clockwise':
        v2 = point_right + point_left
    elif turn == 'anticlockwise':
        v2 = point_left + point_right
    the_angle = angle(v1, v2)

    return the_angle, turn


if __name__ == '__main__':
    the_id = '2'
    re_dir = 'json_' + the_id
    files = os.listdir(re_dir)
    flag = 1
    for file in files:
        re_json_dir = os.path.join(re_dir, file)
        with open(re_json_dir, 'r', encoding='utf8') as fp:
            json_data = json.load(fp)
        # print(json_data)

        image_name = json_data['imagePath']
        point_shapes = json_data['shapes']
        key_points = []
        for point in point_shapes:
            key_points.append(point['points'][0])
        print(image_name)
        print(key_points)
        turn_angle, turn_direction = calculate_rotation_angle(key_points)
        print(turn_angle, turn_direction, sep='  ')
        print('-----------------------------\n')

        # write_line = image_name + ',' + str(turn_angle) + ',' + turn_direction
        # f = open('image_angle.txt', 'a')
        # if flag == 1:
        #     f.write(str(write_line))
        #     flag = 0
        # f.write('\n' + str(write_line))
        # f.close()
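
# Editor's sketch (assumption): quick numerical check of points2circle against
# a known circle -- (0, 1), (1, 0) and (-1, 0) lie on the unit circle:
#
#   pc, r = points2circle([0.0, 1.0], [1.0, 0.0], [-1.0, 0.0])
#   print(pc, r)  # expected: center ~ (0, 0, 0), radius ~ 1.0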
--------------------------------------------------------------------------------
/offset/main/models.py:
--------------------------------------------------------------------------------

from torchsummaryX import summary
from net_util import *


# U-Net downsampling block: two convolutions
class DoubleConv(nn.Module):

    def __init__(self, in_channels, out_channels, channel_reduce=False):
        super(DoubleConv, self).__init__()

        # channel-reduction coefficient
        coefficient = 2 if channel_reduce else 1

        self.down = nn.Sequential(
            nn.Conv2d(in_channels, coefficient * out_channels, kernel_size=(3, 3), padding=1),
            # nn.BatchNorm2d(coefficient * out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(coefficient * out_channels, out_channels, kernel_size=(3, 3), padding=1),
            # nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.down(x)


# upsampling block (transposed convolution plus skip connection)
class Up(nn.Module):

    # note: in_channels is the channel count fed into DoubleConv after the
    # skip concatenation; out_channels is the channel count after it
    def __init__(self, in_channels, out_channels):
        super().__init__()
        # upsample the feature map first
        self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=4, stride=2, padding=1)
        self.conv = DoubleConv(in_channels, out_channels, channel_reduce=True)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x = torch.cat([x1, x2], dim=1)
        x = self.conv(x)
        return x


# simple U-Net model (only the encoder is used; coordinates are regressed from c5)
class U_net(nn.Module):

    def __init__(self):
        super(U_net, self).__init__()

        # downsampling path
        self.double_conv1 = DoubleConv(3, 32)
        self.double_conv2 = DoubleConv(32, 64)
        self.double_conv3 = DoubleConv(64, 128)
        self.double_conv4 = DoubleConv(128, 256)
        self.double_conv5 = DoubleConv(256, 256)

        # coordinate-regression head on top of the encoder
        self.conv1 = nn.Conv2d(256, 64, kernel_size=(1, 1), padding=0)
        self.conv2 = nn.Conv2d(64, 16, kernel_size=(1, 1), padding=0)
        self.fc1 = nn.Linear(11264, 128)
        self.fc2 = nn.Linear(128, 8)

        # upsampling path (unused in this variant)
        self.up1 = Up(512, 128)
        self.up2 = Up(256, 64)
        self.up3 = Up(128, 32)
        self.up4 = Up(64, 16)

        # alternative final layers
        # self.conv = nn.Conv2d(16, 1, kernel_size=(1, 1), padding=0)
        # self.fc1 = nn.Linear(180224, 1024)
        # self.fc2 = nn.Linear(1024, 8)

    def forward(self, x):
        # down (shape comments assume the 512x352 input used in net_util.py)
        c1 = self.double_conv1(x)   # (,32,352,512)
        p1 = nn.MaxPool2d(2)(c1)    # (,32,176,256)
        c2 = self.double_conv2(p1)  # (,64,176,256)
        p2 = nn.MaxPool2d(2)(c2)    # (,64,88,128)
        c3 = self.double_conv3(p2)  # (,128,88,128)
        p3 = nn.MaxPool2d(2)(c3)    # (,128,44,64)
        c4 = self.double_conv4(p3)  # (,256,44,64)
        p4 = nn.MaxPool2d(2)(c4)    # (,256,22,32)
        c5 = self.double_conv5(p4)  # (,256,22,32)
        # no pooling after the last convolution

        # the decoder is bypassed in this variant
        # u1 = self.up1(c5, c4)
        # u2 = self.up2(u1, c3)
        # u3 = self.up3(u2, c2)
        # u4 = self.up4(u3, c1)

        # regression head: map c5 to 8 coordinates
        x1 = self.conv1(c5)  # (,64,22,32)
        x2 = self.conv2(x1)  # (,16,22,32)
        x2 = x2.view(x2.size(0), -1)  # flatten to (,11264)

        x = self.fc1(x2)
        out = self.fc2(x)

        return out

    def summary(self, net):
        x = torch.rand(cfg['batch_size'], 3, 352, 512)  # 352*512
        # move to device
        x = x.to(cfg['device'])
        # output shape
        # print(net(x).shape)

        # show the network structure
        summary(net, x)


# quick debugging entry point
if __name__ == "__main__":
    m = U_net().to(cfg['device'])
    m.summary(m)
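
# Editor's note (assumption): where fc1's in_features = 11264 comes from --
# the 352x512 input is halved by four max-pools to 22x32, and conv2 leaves
# 16 channels, so the flattened vector is 16 * 22 * 32 = 11264.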
--------------------------------------------------------------------------------
/offset/main/net_util.py:
--------------------------------------------------------------------------------

import torch
import os
import numpy as np
from torch import nn
import torchvision
from config import config as cfg
import torch.utils.data
from torchvision import datasets, transforms, models
import cv2
from data_pre import one_json_to_numpy


# dataset repository (class name inherited from an earlier box_3D project)
class Dataset(torch.utils.data.Dataset):
    def __init__(self, dataset_path):
        self.dataset_path = dataset_path
        self.img_name_list = os.listdir(os.path.join(dataset_path, 'imgs'))

    # return the image and label at the given index
    def __getitem__(self, index):
        # image first
        img_path = os.path.join(self.dataset_path, 'imgs', self.img_name_list[index])
        img = cv2.imread(img_path)
        img = cv2.resize(img, (512, 352))
        img = transforms.ToTensor()(img)

        # then the label
        label_path = os.path.join(self.dataset_path, 'labels', self.img_name_list[index].split('.')[0] + '.json')
        mask = one_json_to_numpy(label_path)
        # mask = np.load(os.path.join(self.dataset_path, 'masks', self.img_name_list[index].split('.')[0] + '.npy'))
        mask = torch.tensor(mask, dtype=torch.float32)

        # sanity check: compare base names, not full paths (the 'imgs' and
        # 'labels' directory components always differ)
        if os.path.basename(img_path).split('.')[0] != os.path.basename(label_path).split('.')[0]:
            print("image/label mismatch")

        return img, mask

    # dataset size
    def __len__(self):
        return len(self.img_name_list)
--------------------------------------------------------------------------------
/offset/main/test_main.py:
--------------------------------------------------------------------------------

from net_util import *
import os
from models import *
import cv2
import xml.etree.cElementTree as ET
from xml.etree import ElementTree
import numpy as np
from xml.dom import minidom
import time
from determine_rotation_angle import calculate_rotation_angle
from data_pre import one_json_to_numpy


def show_point_on_picture(img, landmarks, landmarks_gt):
    for point in landmarks[0]:
        point = tuple([int(point[0]), int(point[1])])
        img = cv2.circle(img, center=point, radius=20, color=(0, 0, 255), thickness=-1)
    for point in landmarks_gt:
        point = tuple([int(point[0]), int(point[1])])
        img = cv2.circle(img, center=point, radius=20, color=(0, 255, 0), thickness=-1)
    return img


# prediction
def evaluate():
    date = cfg['test_date']
    way = cfg['test_way']

    # test images
    img_path = os.path.join('..', 'data', date, way, 'imgs')
    # test labels
    label_path = os.path.join('..', 'data', date, way, 'labels')

    # define the model
    model = U_net()

    model.load_state_dict(torch.load(os.path.join('..', 'weights', cfg['pkl_file'])))
    model.to(cfg['device'])
    # model.summary(model)
    model.eval()

    diff_angle_list = []
    total_loss = 0
    # run prediction
    for index, name in enumerate(os.listdir(img_path)):
        print(name + ": " + str(index + 1))

        img = cv2.imread(os.path.join(img_path, name))
        img = cv2.resize(img, (512, 352))
        img = transforms.ToTensor()(img)
        img = torch.unsqueeze(img, dim=0)  # DataLoader adds the batch dimension during training
        print(img.shape)

        # feed the network
        img = img.to(cfg['device'])

        pre = model(img)
        pre = pre.cpu().detach().numpy()
        pre = pre.reshape(pre.shape[0], -1, 2)

        gt_point = one_json_to_numpy(os.path.join(label_path, name.split('.')[0] + '.json'))
        gt_point = gt_point.reshape(-1, 2)

        print(os.path.join(img_path, name))
        print(os.path.join(label_path, name.split('.')[0] + '.json'))

        print(pre)
        print(gt_point)

        pre_label = torch.Tensor(pre.reshape(1, -1)).to(cfg['device'])
        label = torch.Tensor(gt_point.reshape(1, -1)).to(cfg['device'])

        loss_F = torch.nn.MSELoss()
        loss_F.to(cfg['device'])
        loss = loss_F(pre_label, label)  # compute the loss

        print('coordinate MSE:', loss.item())
        total_loss += loss.item()
        print('---------')

        # optional: draw the points and evaluate the rotation angle
        # del img
        #
        # img = cv2.imread(os.path.join(img_path, name))
        # print(img.shape)
        # img = show_point_on_picture(img, pre, gt_point)
        #
        # # save the rendered image
        # save_dir = os.path.join('..', 'result', date, way + '_data')
        # if not os.path.exists(save_dir):
        #     os.makedirs(save_dir)
        #
        # result_dir = os.path.join(save_dir, name.split('.')[0] + '_keypoint.jpg')
        # print(result_dir)
        # cv2.imwrite(result_dir, img)
        #
        # # rotation angle from the predicted keypoints
        # keypoints = pre[0].tolist()
        # pre_angle, pre_turn = calculate_rotation_angle(keypoints)
        # print(pre_angle, pre_turn, sep='  ')
        #
        # # rotation angle from the ground-truth keypoints
        # keypoints = gt_point.tolist()
        # true_angle, true_turn = calculate_rotation_angle(keypoints)
        # print(true_angle, true_turn, sep='  ')
        #
        # if pre_angle < 0:
        #     diff_angle = -100000
        # else:
        #     diff_angle = abs(true_angle - pre_angle)
        # print(diff_angle)
        #
        # diff_angle_list.append(diff_angle)
        # print('=========')
        #
        # print(np.mean(diff_angle_list))

    print(total_loss / (index + 1))


if __name__ == "__main__":
    # choose weights
    # choose_weights()

    # evaluate one set of weights
    evaluate()
--------------------------------------------------------------------------------
/offset/main/train_main.py:
--------------------------------------------------------------------------------

from models import *
from config import config as cfg


# training entry point
def train():
    print('start')

    save_epoch = cfg['save_epoch']
    min_avg_loss = 5000
    date = cfg['train_date']

    # model
    model = U_net()
    # weights
    start_epoch = 0
    # model.load_state_dict(torch.load(os.path.join('..', 'weights', 'epoch_001.pkl')))
    model.to(cfg['device'])
    model.summary(model)

    start_epoch += 1

    # data repository
    dataset = Dataset(os.path.join('..', 'data', date, 'train'))

    train_data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                                    batch_size=cfg['batch_size'],
                                                    shuffle=True)

    # loss and optimizer
    loss_F = torch.nn.MSELoss()
    loss_F.to(cfg['device'])
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

    for epoch in range(start_epoch, cfg['epochs'], 1):
        # model.train()
        total_loss = 0.0
        min_loss = 10000
        max_loss = 10
        # iterate over batches
        for index, (x, y) in enumerate(train_data_loader):
            img = x.to(cfg['device'])
            label = y.to(cfg['device'])

            pre = model(img)  # forward pass

            loss = loss_F(pre, label)  # compute the loss
            optimizer.zero_grad()      # gradients must be cleared before each backward pass
            loss.backward()            # populate gradients
            optimizer.step()           # update parameters
            total_loss += loss.item()

            if loss < min_loss:
                min_loss = loss

            if loss > max_loss:
                max_loss = loss

            # if (index+1) % 5 == 0:
            #     print('Epoch %d loss %f' % (epoch, total_loss / (index + 1)))

        avg_loss = total_loss / (index + 1)

        if avg_loss < min_avg_loss:
            print(pre)
            print(label)
            print('-------------------')
            min_avg_loss = avg_loss
            torch.save(model.state_dict(), os.path.join('..', "weights", 'min_loss.pth'))

        print('Epoch %d, photo number %d, avg loss %f, min loss %f, max loss %f, min avg loss %f' % (epoch, index + 1, avg_loss, min_loss, max_loss, min_avg_loss))

        print('========================')

        # save weights every save_epoch epochs
        save_name = "epoch_" + str(epoch).zfill(3) + ".pth"
        if epoch % save_epoch == 0:
            torch.save(model.state_dict(), os.path.join('..', "weights", save_name))


if __name__ == "__main__":
    # train
    train()

--------------------------------------------------------------------------------
/offset/result/08_11/test_data/202107170976_4_keypoint.jpg:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ExileSaber/KeyPoint-Detection/3fbe892aaed83b305080364d3da1664abb64c53b/offset/result/08_11/test_data/202107170976_4_keypoint.jpg --------------------------------------------------------------------------------