├── README.md
├── heatmap
│   ├── README.md
│   ├── data
│   │   └── 08_11
│   │       └── test
│   │           ├── imgs
│   │           │   └── 202108110759_6.jpg
│   │           └── labels
│   │               └── 202108110759_6.json
│   ├── main
│   │   ├── config_loc.py
│   │   ├── data_pre_loc.py
│   │   ├── models_loc.py
│   │   ├── net_util_loc.py
│   │   ├── test_main.py
│   │   └── train_main.py
│   ├── requirements.txt
│   └── result
│       └── 08_11
│           └── test
│               └── 202108110759_6_keypoint.jpg
└── offset
    ├── README.md
    ├── data
    │   └── 08_11
    │       └── test
    │           ├── imgs
    │           │   └── 202107170976_4.jpg
    │           └── labels
    │               └── 202107170976_4.json
    ├── main
    │   ├── config.py
    │   ├── data_pre.py
    │   ├── determine_rotation_angle.py
    │   ├── models.py
    │   ├── net_util.py
    │   ├── test_main.py
    │   └── train_main.py
    └── result
        └── 08_11
            └── test_data
                └── 202107170976_4_keypoint.jpg
/README.md:
--------------------------------------------------------------------------------
1 | # Code Overview
2 | 
3 | 1. The offset folder implements the approach taken by early researchers: the network directly outputs the point coordinates.
4 |    - The drawback of this approach is that prediction is very hard; apart from the Euclidean-distance value there is no other reference signal.
5 | 
6 | 2. The heatmap folder implements the approach now more commonly used: each keypoint is converted into heatmap-style data.
7 |    - This approach achieves better results; sample results on the industrial inspection task are shown in the README inside the heatmap folder.
8 | 
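9 | A minimal sketch contrasting the two label formats (the values and the 256x256 size are hypothetical, not taken from this repo; the real heatmap code scales the peak to 255):
10 | 
11 | ```python
12 | import numpy as np
13 | import cv2
14 | 
15 | point = (120, 80)  # one annotated keypoint, as (x, y)
16 | 
17 | # offset style: the label is the coordinate vector itself
18 | offset_label = np.array(point, dtype=np.float32)
19 | 
20 | # heatmap style: the label is a Gaussian bump centred on the keypoint
21 | heatmap_label = np.zeros((256, 256), dtype=np.float32)
22 | heatmap_label[point[1], point[0]] = 1.0
23 | heatmap_label = cv2.GaussianBlur(heatmap_label, (51, 51), 0)
24 | heatmap_label /= heatmap_label.max()  # normalise the peak
25 | ```
26 | 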
--------------------------------------------------------------------------------
/heatmap/README.md:
--------------------------------------------------------------------------------
1 | ## KeyPoint-Detection/heatmap
2 | Gaussian heatmaps are used as the network output; the predicted keypoint accuracy improves further, with the error kept within a few pixels.
3 | The data folder contains one sample image together with its annotated json file.
4 | 
5 | ## Main points
6 | * Gaussian heatmaps are used as labels, and the network outputs heatmaps of the same dimensions
7 | * The network model is a U-net
8 | * The loss function is torch.nn.MSELoss()
9 | * The maximum of each predicted heatmap is taken as the predicted keypoint position (see the sketch below the result image)
10 | 
11 | ## Model results
12 | Green points are the annotated keypoints; red points are the predicted keypoints
13 | 
14 | 
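15 | As a minimal sketch, the decoding step described above (take the heatmap maximum as the keypoint) looks like the following; the function name and shapes here are illustrative, the actual code lives in main/data_pre_loc.py:
16 | 
17 | ```python
18 | import numpy as np
19 | 
20 | def decode_heatmaps(heatmaps):
21 |     """heatmaps: array of shape (kpt_n, H, W); returns kpt_n points as (x, y)."""
22 |     points = []
23 |     for hm in heatmaps:
24 |         row, col = np.unravel_index(np.argmax(hm), hm.shape)
25 |         points.append((col, row))  # argmax gives (row, col); swap to (x, y)
26 |     return np.array(points)
27 | ```
28 | 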
--------------------------------------------------------------------------------
/heatmap/data/08_11/test/imgs/202108110759_6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExileSaber/KeyPoint-Detection/3fbe892aaed83b305080364d3da1664abb64c53b/heatmap/data/08_11/test/imgs/202108110759_6.jpg
--------------------------------------------------------------------------------
/heatmap/main/config_loc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | config_loc = {
5 |     # network training settings
6 | # 'device': torch.device("cuda" if torch.cuda.is_available() else "cpu"),
7 | 'device': torch.device("cuda"),
8 | 'batch_size': 1,
9 | 'epochs': 510,
10 | 'save_epoch': 100,
11 | 'learning_rate': 0.0001,
12 |     'lr_scheduler': 'step1',  # options: 'step1' or 'step2' step decay, or 'exponential' exponential decay
13 |
14 |     # original image size
15 | 'img_h': 3036,
16 | 'img_w': 4024,
17 |
18 |     # size after cropping
19 | 'cut_h': 3008,
20 | 'cut_w': 3968,
21 |
22 |     # image size fed to the network
23 | 'input_h': 752,
24 | 'input_w': 992,
25 |
26 |     # Gaussian kernel size
27 | 'gauss_h': 51,
28 | 'gauss_w': 51,
29 |
30 |     # number of keypoints
31 | 'kpt_n': 3,
32 |
33 |     # network evaluation settings
34 | 'test_batch_size': 1,
35 | 'test_threshold': 0.5,
36 |
37 | 'path': '/home/lwm/Disk_D/Projects/IndustrialProjects/Haier',
38 |
39 |     # path settings
40 | 'train_date': 'all_9_20220219',
41 | 'train_way': 'train',
42 | 'test_date': 'all_9_20220219',
43 | 'test_way': 'train',
44 |
45 |     # model checkpoint used for testing
46 | # 'pkl_file': '20211123_2.pth',
47 | 'pkl_file': 'all_9_0216.pth',
48 |
49 |     # whether to load a pretrained model
50 | 'use_old_pkl': True,
51 | 'old_pkl': 'min_loss.pth',
52 |
53 | # # pytorch < 1.6
54 | # 'pytorch_version': False,
55 |
56 | # remember location
57 | 'start_x': 200,
58 | 'start_y': 200,
59 | 'start_angle': 0,
60 |
61 | # max x,y
62 | 'max_x': 300,
63 | 'max_y': 250,
64 | 'max_angle': 90,
65 |
66 | # min x,y
67 | 'min_x': 100,
68 | 'min_y': 100,
69 |
70 | # key points relative location
71 | 'distance_12': 360,
72 | 'distance_13': 200,
73 | 'distance_23': 410,
74 |
75 | 'delta': 50,
76 |
77 | # 'photo_to_world':
78 | }
79 |
--------------------------------------------------------------------------------
/heatmap/main/data_pre_loc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 | import math
5 | import matplotlib.pyplot as plt
6 | from config_loc import config_loc as cfg
7 | from scipy.ndimage import gaussian_filter
8 | import cv2
9 | import PIL
10 | from torchvision import transforms
11 | import torch
12 | from PIL import Image, ImageFont, ImageDraw
13 | # os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
14 |
15 |
16 | # convert a labelme json annotation into a numpy array of keypoint coordinates
17 | def json_to_numpy(dataset_path):
18 | with open(dataset_path) as fp:
19 | json_data = json.load(fp)
20 | points = json_data['shapes']
21 |
22 | # print(points)
23 | landmarks = []
24 | for point in points:
25 | for p in point['points']:
26 | landmarks.append(p)
27 |
28 | # print(landmarks)
29 | landmarks = np.array(landmarks)
30 | landmarks = landmarks.reshape(-1, 2)
31 |
32 |     # optionally save as .npy
33 | # np.save(os.path.join(save_path, name.split('.')[0] + '.npy'), landmarks)
34 |
35 | return landmarks
36 |
37 |
38 | def generate_heatmaps(landmarks, height, width, sigma):
39 |
40 | # img_h = cfg['img_h']
41 | # img_w = cfg['img_w']
42 |
43 | cut_h = cfg['cut_h']
44 | cut_w = cfg['cut_w']
45 |
46 | heatmaps = []
47 | for points in landmarks:
48 | heatmap = np.zeros((height, width))
49 |
50 | # ch = int(height * points[1] / img_h)
51 | # cw = int(width * points[0] / img_w)
52 | ch = int(height * points[1] / cut_h)
53 | cw = int(width * points[0] / cut_w)
54 | heatmap[ch][cw] = 1
55 |
56 |         heatmap = cv2.GaussianBlur(heatmap, sigma, 0)  # note: 'sigma' here is the Gaussian kernel size (h, w)
57 | am = np.amax(heatmap)
58 | heatmap /= am / 255
59 | heatmaps.append(heatmap)
60 |
61 | heatmaps = np.array(heatmaps)
62 | # heatmaps = np.expand_dims(heatmaps, axis=0)
63 |
64 | return heatmaps
65 |
66 |
67 | def show_heatmap(heatmaps):
68 | for heatmap in heatmaps:
69 | plt.imshow(heatmap, cmap='hot', interpolation='nearest')
70 | plt.show()
71 |
72 |
73 | def heatmap_to_point(heatmaps):
74 | # img_h = cfg['img_h']
75 | # img_w = cfg['img_w']
76 | cut_h = cfg['cut_h']
77 | cut_w = cfg['cut_w']
78 | input_h = cfg['input_h']
79 | input_w = cfg['input_w']
80 |
81 | points = []
82 | for heatmap in heatmaps:
83 |         pos = np.unravel_index(np.argmax(heatmap), heatmap.shape)  # pos is (row, col)
84 |         point_x = cut_w * (pos[1] / input_w)
85 |         point_y = cut_h * (pos[0] / input_h)
86 |         points.append([point_x, point_y])
87 | return np.array(points)
88 |
89 |
90 | def show_inputImg_and_keypointLabel(imgPath, heatmaps):
91 | points = []
92 | for heatmap in heatmaps:
93 | pos = np.unravel_index(np.argmax(heatmap), heatmap.shape)
94 | points.append([pos[1], pos[0]])
95 |
96 | img = PIL.Image.open(imgPath).convert('RGB')
97 |     img = transforms.ToTensor()(img)  # tensor of shape C*H*W
98 | img = img[:, :cfg['cut_h'], :cfg['cut_w']]
99 |
100 |     img = img.unsqueeze(0)  # add a batch dimension
101 | resize = torch.nn.Upsample(scale_factor=(0.25, 0.25), mode='bilinear', align_corners=True)
102 | img = resize(img)
103 |
104 |     img = img.squeeze(0)  # remove the batch dimension
105 |
106 | print(img.shape)
107 |
108 | img = transforms.ToPILImage()(img)
109 | draw = ImageDraw.Draw(img)
110 | for point in points:
111 | print(point)
112 | draw.point((point[0], point[1]), fill='yellow')
113 |
114 |     # save the result
115 | img.save(os.path.join('..','show', 'out.jpg'))
116 |
117 |
118 | if __name__ == '__main__':
119 | landmarks = json_to_numpy('../data/0828_back/labels/202108280005_2.json')
120 |     print('keypoint coordinates', landmarks, '-------------', sep='\n')
121 |
122 | heatmaps = generate_heatmaps(landmarks, cfg['input_h'], cfg['input_w'], (cfg['gauss_h'], cfg['gauss_w']))
123 | # print(heatmaps)
124 | print(heatmaps.shape)
125 |
126 | # show heatmap picture
127 | # show_heatmap(heatmaps)
128 |
129 | # show cut image and the keypoints
130 | # show_inputImg_and_keypointLabel('../data/08_11_in/train/imgs/202108110556_6.jpg', heatmaps)
131 |
--------------------------------------------------------------------------------
/heatmap/main/models_loc.py:
--------------------------------------------------------------------------------
1 | from torchsummaryX import summary
2 | from net_util_loc import *
3 | from torchviz import make_dot
4 | import tensorwatch as tw
5 | from tensorboardX import SummaryWriter
6 |
7 |
8 | # U-net downsampling block: two convolutions
9 | class DoubleConv(nn.Module):
10 |
11 |     def __init__(self, in_channels, out_channels, channel_reduce=False):  # only defines the layers the network needs
12 | super(DoubleConv, self).__init__()
13 |
14 |         # channel-reduction coefficient
15 | coefficient = 2 if channel_reduce else 1
16 |
17 | self.down = nn.Sequential(
18 | nn.Conv2d(in_channels, coefficient * out_channels, kernel_size=(3, 3), padding=1),
19 | nn.BatchNorm2d(coefficient * out_channels),
20 | nn.ReLU(inplace=True),
21 | nn.Conv2d(coefficient * out_channels, out_channels, kernel_size=(3, 3), padding=1),
22 | nn.BatchNorm2d(out_channels),
23 | nn.ReLU(inplace=True)
24 | )
25 |
26 | def forward(self, x):
27 | return self.down(x)
28 |
29 |
30 | # upsampling block (transposed convolution plus skip connection)
31 | class Up(nn.Module):
32 |
33 |     # note the inputs: in_channels is the channel count fed into DoubleConv, out_channels is the channel count after it
34 | def __init__(self, in_channels, out_channels):
35 | super().__init__()
36 |         # first upsample the feature map
37 | self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=4, stride=2, padding=1)
38 | self.conv = DoubleConv(in_channels, out_channels, channel_reduce=True)
39 |
40 | def forward(self, x1, x2):
41 | x1 = self.up(x1)
42 |
43 | # print(x1.shape, x2.shape)
44 | x = torch.cat([x1, x2], dim=1)
45 | x = self.conv(x)
46 | return x
47 |
48 |
49 | # simple U-net model
50 | class U_net(nn.Module):
51 |
52 |     def __init__(self):  # only defines the layers the network needs
53 | super(U_net, self).__init__()
54 |
55 |         # downsampling path
56 | self.double_conv1 = DoubleConv(3, 32)
57 | self.double_conv2 = DoubleConv(32, 64)
58 | self.double_conv3 = DoubleConv(64, 128)
59 | self.double_conv4 = DoubleConv(128, 256)
60 | self.double_conv5 = DoubleConv(256, 256)
61 |
62 |         # upsampling path
63 | self.up1 = Up(512, 128)
64 | self.up2 = Up(256, 64)
65 | self.up3 = Up(128, 32)
66 | self.up4 = Up(64, 16)
67 |
68 |         # final layer
69 | self.out = nn.Conv2d(16, cfg['kpt_n'], kernel_size=(1, 1), padding=0)
70 |
71 | def forward(self, x):
72 | # down
73 | c1 = self.double_conv1(x) # (,32,512,512)
74 | p1 = nn.MaxPool2d(2)(c1) # (,32,256,256)
75 | c2 = self.double_conv2(p1) # (,64,256,256)
76 | p2 = nn.MaxPool2d(2)(c2) # (,64,128,128)
77 | c3 = self.double_conv3(p2) # (,128,128,128)
78 | p3 = nn.MaxPool2d(2)(c3) # (,128,64,64)
79 | c4 = self.double_conv4(p3) # (,256,64,64)
80 | p4 = nn.MaxPool2d(2)(c4) # (,256,32,32)
81 | c5 = self.double_conv5(p4) # (,256,32,32)
82 |         # no pooling after the last convolution
83 |
84 | # up
85 | u1 = self.up1(c5, c4) # (,128,64,64)
86 | u2 = self.up2(u1, c3) # (,64,128,128)
87 | u3 = self.up3(u2, c2) # (,32,256,256)
88 | u4 = self.up4(u3, c1) # (,16,512,512)
89 |
90 |         # final layer: map to cfg['kpt_n'] heatmaps
91 | out = self.out(u4)
92 |
93 | return out
94 |
95 | def summary(self, net):
96 |         x = torch.rand(cfg['batch_size'], 3, cfg['input_h'], cfg['input_w'])  # dummy input at the configured size
97 |         # move to the device
98 | x = x.to(cfg['device'])
99 |         # shape of the output y
100 | # print(net(x).shape)
101 |
102 |         # print the network structure
103 | summary(net, x)
104 |
105 |
106 | # debugging entry point
107 | if __name__ == "__main__":
108 | m = U_net().to(cfg['device'])
109 |
110 | m.summary(m)
111 | img = tw.draw_model(m, [cfg['batch_size'], 3, cfg['input_h'], cfg['input_w']])
112 | # # print(img)
113 | img.save('/home/mlg1504/bolt_project/algorithm/keypoint_det/code2/U-net.png')
114 |
115 | x = torch.rand(cfg['batch_size'], 3, cfg['input_h'], cfg['input_w'])
116 | x = x.to(cfg['device'])
117 | # model = U_net()
118 | # with SummaryWriter(comment='U-Net') as w:
119 | # w.add_graph(m, x)
120 |
121 | # y = m(x)
122 | #
123 | # g = make_dot(y)
124 | #
125 | # g.render('espnet_model', view=False)
126 |
127 |
--------------------------------------------------------------------------------
/heatmap/main/net_util_loc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import numpy as np
4 | from torch import nn
5 | import torchvision
6 | from config_loc import config_loc as cfg
7 | import torch.utils.data
8 | from torchvision import datasets, transforms, models
9 | import cv2
10 | import PIL
11 | from PIL import Image, ImageFont, ImageDraw
12 | from data_pre_loc import json_to_numpy, generate_heatmaps
13 |
14 |
15 | # Dataset wrapper for the keypoint data
16 | class Dataset(torch.utils.data.Dataset):
17 |     # initialisation
18 | def __init__(self, dataset_path):
19 | self.dataset_path = dataset_path
20 | self.img_name_list = os.listdir(os.path.join(dataset_path, cfg['train_way'], 'imgs'))
21 |
22 |     # return the image and label for the given index
23 | def __getitem__(self, index):
24 |         # process the image first
25 | img_name = self.img_name_list[index]
26 | img = PIL.Image.open(os.path.join(self.dataset_path, cfg['train_way'], 'imgs', img_name)).convert('RGB')
27 |         img = transforms.ToTensor()(img)  # tensor of shape C*H*W
28 | img = img[:, :cfg['cut_h'], :cfg['cut_w']]
29 |
30 |         img = img.unsqueeze(0)  # add a batch dimension
31 |         resize = torch.nn.Upsample(scale_factor=(0.25, 0.25), mode='bilinear', align_corners=True)
32 |         img = resize(img).squeeze(0)  # downsample by 4 and drop the batch dimension
33 | # print(img.shape)
34 |
35 |         # read the label
36 | mask_name = img_name.split('.')[0] + '.json'
37 | mask = json_to_numpy(os.path.join(self.dataset_path, cfg['train_way'], 'labels', mask_name))
38 | # mask = np.load(os.path.join(self.dataset_path, 'labels', self.img_name_list[index].split('.')[0] + '.json'),allow_pickle=True)
39 | # mask = torch.tensor(mask, dtype=torch.float32)
40 |
41 | heatmaps = generate_heatmaps(mask, cfg['input_h'], cfg['input_w'], (cfg['gauss_h'], cfg['gauss_w']))
42 | heatmaps = torch.tensor(heatmaps, dtype=torch.float32)
43 |
44 | return img, heatmaps, img_name
45 |
46 |     # size of the dataset
47 | def __len__(self):
48 | return len(self.img_name_list)
49 |
--------------------------------------------------------------------------------
/heatmap/main/test_main.py:
--------------------------------------------------------------------------------
1 | from net_util_loc import *
2 | import os
3 | from models_loc import *
4 | import cv2
5 | import PIL
6 | import xml.etree.cElementTree as ET
7 | from xml.etree import ElementTree
8 | import numpy as np
9 | from xml.dom import minidom
10 | import time
11 | # from determine_rotation_angle_loc import calculate_rotation_angle
12 | # from determine_location_loc import determine_location  # this module is not included in the repo
13 | from data_pre_loc import json_to_numpy, generate_heatmaps, heatmap_to_point
14 |
15 |
16 | def show_point_on_picture(img, landmarks, landmarks_gt):
17 | for point in landmarks:
18 | point = tuple([int(point[0]), int(point[1])])
19 | # print(point)
20 | img = cv2.circle(img, center=point, radius=20, color=(0, 0, 255), thickness=-1)
21 | for point in landmarks_gt:
22 | point = tuple([int(point[0]), int(point[1])])
23 | # print(point)
24 | img = cv2.circle(img, center=point, radius=20, color=(0, 255, 0), thickness=-1)
25 | return img
26 |
27 |
28 | # prediction
29 | def evaluate(flag=False):
30 | date = cfg['test_date']
31 | way = cfg['test_way']
32 |
33 |     # test image folder
34 | img_path = os.path.join('..', 'data', date, way, 'imgs')
35 |     # test label folder
36 | label_path = os.path.join('..', 'data', date, way, 'labels')
37 |
38 |     # define the model
39 | model = U_net()
40 |
41 | # 034, 128
42 | model.load_state_dict(torch.load(os.path.join('..', 'weights', cfg['pkl_file'])))
43 | model.to(cfg['device'])
44 | # model.summary(model)
45 | model.eval()
46 |
47 |     # downsampling module
48 | resize = torch.nn.Upsample(scale_factor=(1, 0.5), mode='bilinear', align_corners=True)
49 |
50 | total_loss = 0
51 | diff_angle_list = []
52 | diff_delta_x_list = []
53 | diff_delta_y_list = []
54 | max_keypoint_diff = 0
55 |
56 |     # start prediction
57 | for index, name in enumerate(os.listdir(img_path)):
58 |         print('image name:', name + '  image index: ' + str(index + 1))
59 |
60 | # img = cv2.imread(os.path.join(img_path, name))
61 | # img = cv2.resize(img, (cfg['input_w'], cfg['input_h']))
62 | # img = transforms.ToTensor()(img)
63 |         # img = torch.unsqueeze(img, dim=0)  # during training DataLoader adds the batch dimension automatically
64 |
65 | img = PIL.Image.open(os.path.join(img_path, name)).convert('RGB')
66 |         img = transforms.ToTensor()(img)  # tensor of shape C*H*W
67 | img = img[:, :cfg['cut_h'], :cfg['cut_w']]
68 |
69 |         img = img.unsqueeze(0)  # add a batch dimension
70 | resize = torch.nn.Upsample(scale_factor=(0.25, 0.25), mode='bilinear', align_corners=True)
71 | img = resize(img)
72 |
73 |         print('shape of the tensor fed to the network:', img.shape)
74 |
75 |         # feed into the network
76 | img = img.to(cfg['device'])
77 |
78 | pre = model(img)
79 | pre = pre.cpu().detach().numpy()
80 | # pre = pre.reshape(pre.shape[0], -1, 2)
81 |
82 | pre_point = heatmap_to_point(pre[0])
83 |
84 | point = json_to_numpy(os.path.join(label_path, name.split('.')[0] + '.json'))
85 |
86 |         print('image jpg path:', os.path.join(img_path, name))
87 |         print('annotation json path:', os.path.join(label_path, name.split('.')[0] + '.json'))
88 | 
89 |         print('predicted keypoint coordinates:\n', pre_point)
90 |         print('ground-truth keypoint coordinates:\n', point)
91 |
92 | pre_label = torch.Tensor(pre_point.reshape(1, -1)).to(cfg['device'])
93 | label = torch.Tensor(point.reshape(1, -1)).to(cfg['device'])
94 |
95 | loss_F = torch.nn.MSELoss()
96 | loss_F.to(cfg['device'])
97 |         loss = loss_F(pre_label, label)  # compute the loss
98 |
99 |         print('+++ coordinate error loss: ', loss.item())
100 | total_loss += loss.item()
101 |
102 | if loss.item() > max_keypoint_diff:
103 | max_keypoint_diff = loss.item()
104 |
105 | print('---------')
106 |
107 |         if flag:
108 | del img
109 |
110 | img = cv2.imread(os.path.join(img_path, name))
111 | img = show_point_on_picture(img, pre_point, point)
112 |
113 |             # save the rendered image
114 | save_dir = os.path.join('..', 'result', date, way + '_data')
115 | if not os.path.exists(save_dir):
116 | os.makedirs(save_dir)
117 |
118 | result_dir = os.path.join(save_dir, name.split('.')[0] + '_keypoint.jpg')
119 |             print('saving the image with keypoints drawn to:', result_dir)
120 | cv2.imwrite(result_dir, img)
121 |
122 |
123 | print('##################')
124 | print('# ---- Mean ---- #')
125 | print('##################')
126 |
127 |     print('mean keypoint coordinate loss:', total_loss / (index + 1), ' max per-image keypoint loss:', max_keypoint_diff)
128 |
129 |
130 | if __name__ == "__main__":
131 | # choose weights
132 | # choose_weights()
133 |
134 |     # run prediction with one set of weights
135 | evaluate(flag=True)
136 |
--------------------------------------------------------------------------------
/heatmap/main/train_main.py:
--------------------------------------------------------------------------------
1 | from models_loc import *
2 | from config_loc import config_loc as cfg
3 |
4 |
5 | # training entry point
6 | def train():
7 | print('start')
8 |
9 | save_epoch = cfg['save_epoch']
10 | min_avg_loss = 5000
11 | date = cfg['train_date']
12 |
13 |     # model
14 | model = U_net()
15 |
16 |     # load weights
17 | start_epoch = 0
18 |
19 |     # load an initial model if requested
20 | if cfg['use_old_pkl'] is True:
21 | model.load_state_dict(torch.load(os.path.join('..', 'weights', cfg['old_pkl'])))
22 |         print('model loaded')
23 |
24 | model.to(cfg['device'])
25 | model.summary(model)
26 |
27 | start_epoch += 1
28 |
29 |     # dataset
30 | dataset = Dataset(os.path.join('..', 'data', cfg['train_date']))
31 |
32 | train_data_loader = torch.utils.data.DataLoader(dataset=dataset,
33 | batch_size=cfg['batch_size'],
34 | shuffle=True)
35 |
36 |     # loss function and optimizer
37 | loss_F = torch.nn.MSELoss()
38 | loss_F.to(cfg['device'])
39 | optimizer = torch.optim.Adam(model.parameters(), lr=cfg['learning_rate'])
40 |
41 |     # learning-rate decay could be added here
42 |
43 | max_loss_name = ''
44 | for epoch in range(start_epoch, cfg['epochs'], 1):
45 | # model.train()
46 | total_loss = 0.0
47 | min_loss = 1000
48 | max_loss = 0.001
49 |         # iterate over the batches
50 | for index, (x, y, img_name) in enumerate(train_data_loader):
51 | img = x.to(cfg['device'])
52 | label = y.to(cfg['device'])
53 |
54 | # print('------------')
55 | # print(img.shape)
56 | # print('------------')
57 |
58 |             pre = model(img)  # forward pass
59 |             # compute the loss and backpropagate
60 | # print(pre.shape)
61 | # print('------------')
62 | # print(label.shape)
63 |
64 |             loss = loss_F(pre, label)  # compute the loss
65 |             optimizer.zero_grad()  # gradients must be zeroed before each backward pass
66 |             loss.backward()  # backpropagate to populate the gradients
67 |             optimizer.step()  # update the parameters
68 | total_loss += loss.item()
69 |
70 |             if loss.item() < min_loss:
71 |                 min_loss = loss.item()
72 | 
73 |             if loss.item() > max_loss:
74 |                 max_loss = loss.item()
75 |                 max_loss_name = img_name[0]
76 |
77 | # if (index+1) % 5 == 0:
78 | # print('Epoch %d loss %f' % (epoch, total_loss / (index + 1)))
79 |
80 | avg_loss = total_loss/(index+1)
81 |
82 | if avg_loss < min_avg_loss:
83 | min_avg_loss = avg_loss
84 | torch.save(model.state_dict(), os.path.join('..', "weights", 'min_loss.pth'))
85 | # if cfg['pytorch_version'] is False:
86 | # torch.save(model.state_dict(), os.path.join('..', "weights", 'old_version_min_loss.pth'), _use_new_zipfile_serialization=cfg['pytorch_version'])
87 |
88 | print('Epoch %d, photo number %d,avg loss %f, min loss %f, max loss %f, max loss name %s,min avg loss %f' % (epoch, index+1, avg_loss, min_loss, max_loss, max_loss_name, min_avg_loss))
89 |
90 | print('-------------------')
91 |
92 |         # save the weights every save_epoch epochs
93 | # save_name = "epoch_" + str(epoch).zfill(3) + ".pth"
94 | # if (epoch) % save_epoch == 0:
95 | # torch.save(model.state_dict(), os.path.join('..', "weights", save_name))
96 |
97 |
98 | if __name__ == "__main__":
99 |     # train
100 | train()
101 |
--------------------------------------------------------------------------------
/heatmap/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.2.2
2 | numpy==1.21.6
3 | opencv_python==4.5.4.60
4 | Pillow==9.2.0
5 | scipy==1.7.3
6 | tensorboardX==2.5.1
7 | tensorwatch==0.9.1
8 | torch==1.8.1
9 | torchsummaryX==1.3.0
10 | torchvision==0.11.3
11 | torchviz==0.0.2
12 |
--------------------------------------------------------------------------------
/heatmap/result/08_11/test/202108110759_6_keypoint.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExileSaber/KeyPoint-Detection/3fbe892aaed83b305080364d3da1664abb64c53b/heatmap/result/08_11/test/202108110759_6_keypoint.jpg
--------------------------------------------------------------------------------
/offset/README.md:
--------------------------------------------------------------------------------
1 | # Keypoint-Detection/offset
2 | Keypoint detection on self-annotated industrial images; each image is annotated with 4 keypoints.
3 | 
4 | ## Main points
5 | * This part was a first attempt at a detection-style task, done for practice and to understand the network
6 | * The network is the downsampling (encoder) part of U-net
7 | * Labels use the Coordinate method, and the loss is simply the squared distance between the predicted and annotated coordinates (see the sketch below the result image)
8 | 
9 | 
10 | ## Model results
11 | Green points are the annotated keypoints; red points are the predicted keypoints
12 | 
13 | 
14 | 
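15 | A minimal sketch of the Coordinate-style loss (the values below are hypothetical; the training code in main/train_main.py applies torch.nn.MSELoss in the same way):
16 | 
17 | ```python
18 | import torch
19 | 
20 | # 4 keypoints flattened to 8 coordinates (x1, y1, ..., x4, y4)
21 | pred = torch.tensor([[110., 82., 300., 95., 210., 260., 40., 190.]])
22 | label = torch.tensor([[112., 80., 298., 97., 213., 255., 42., 188.]])
23 | 
24 | loss = torch.nn.MSELoss()(pred, label)  # mean squared coordinate error
25 | print(loss.item())
26 | ```
27 | 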
--------------------------------------------------------------------------------
/offset/data/08_11/test/imgs/202107170976_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExileSaber/KeyPoint-Detection/3fbe892aaed83b305080364d3da1664abb64c53b/offset/data/08_11/test/imgs/202107170976_4.jpg
--------------------------------------------------------------------------------
/offset/main/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | config = {
5 |     # network training settings
6 | 'device': torch.device("cuda" if torch.cuda.is_available() else "cpu"),
7 | 'batch_size': 1,
8 | 'epochs': 510,
9 | 'save_epoch': 100,
10 |
11 |     # path settings
12 | 'train_date': 'all_out',
13 | 'train_way': 'train',
14 | 'test_date': 'all_in',
15 | 'test_way': 'test',
16 |
17 |     # model checkpoint used for testing
18 | 'pkl_file': 'min_loss_all_in.pth'
19 |
20 | }
21 |
--------------------------------------------------------------------------------
/offset/main/data_pre.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | from config import config as cfg
6 | import cv2
7 |
8 |
9 | # read the labelme json labels and return the keypoint coordinates as a numpy array
10 | def json_to_numpy(dataset_path):
11 |     # derive the image and label folders from the dataset root
12 |     imgs_path = os.path.join(dataset_path, 'imgs')
13 |     labels_path = os.path.join(dataset_path, 'labels')
14 |
15 |     # process every annotated image (the landmarks of the last file are returned)
16 | for name in os.listdir(imgs_path):
17 |         # read the label
18 | with open(os.path.join(os.path.join(labels_path),
19 | name.split('.')[0] + '.json'), 'r', encoding='utf8')as fp:
20 | json_data = json.load(fp)
21 | points = json_data['shapes']
22 |
23 | # print(points)
24 | landmarks = []
25 | for point in points:
26 | for p in point['points'][0]:
27 | landmarks.append(p)
28 |
29 | # print(landmarks)
30 | landmarks = np.array(landmarks)
31 |
32 |         # optionally save as .npy
33 | # np.save(os.path.join(save_path, name.split('.')[0] + '.npy'), landmarks)
34 |
35 | return landmarks
36 |
37 |
38 | def one_json_to_numpy(dataset_path):
39 | with open(dataset_path) as fp:
40 | json_data = json.load(fp)
41 | points = json_data['shapes']
42 |
43 | # print(points)
44 | landmarks = []
45 | for point in points:
46 | for p in point['points'][0]:
47 | landmarks.append(p)
48 |
49 | # print(landmarks)
50 | landmarks = np.array(landmarks)
51 |
52 |     # optionally save as .npy
53 | # np.save(os.path.join(save_path, name.split('.')[0] + '.npy'), landmarks)
54 |
55 | return landmarks
56 |
57 |
58 | if __name__ == '__main__':
59 |     json_to_numpy(os.path.join('..', 'data', '08_11', 'test'))  # run on the sample data shipped with the repo
60 |
--------------------------------------------------------------------------------
/offset/main/determine_rotation_angle.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import math
4 | import numpy as np
5 | from numpy.linalg import det
6 | 
7 | 
8 | # find the centre and radius of the circle determined by three points
9 | def points2circle(p1, p2, p3):
10 |     p1 = np.array(p1)
11 |     p2 = np.array(p2)
12 |     p3 = np.array(p3)
13 |     num1 = len(p1)
14 |     num2 = len(p2)
15 |     num3 = len(p3)
16 | 
17 |     # input check
18 |     if (num1 == num2) and (num2 == num3):
19 |         if num1 == 2:
20 |             p1 = np.append(p1, 0)
21 |             p2 = np.append(p2, 0)
22 |             p3 = np.append(p3, 0)
23 |         elif num1 != 3:
24 |             print('\tonly 2-D or 3-D coordinates are supported')
25 |             return None
26 |     else:
27 |         print('\tthe input coordinates have inconsistent dimensions')
28 |         return None
29 | 
30 |     # collinearity check
31 |     temp01 = p1 - p2
32 |     temp02 = p3 - p2
33 |     temp03 = np.cross(temp01, temp02)
34 |     temp = (temp03 @ temp03) / (temp01 @ temp01) / (temp02 @ temp02)
35 |     if temp < 10**-6:
36 |         print('\tthe three points are collinear and do not determine a circle')
37 |         return None
38 | 
39 |     temp1 = np.vstack((p1, p2, p3))
40 |     temp2 = np.ones(3).reshape(3, 1)
41 |     mat1 = np.hstack((temp1, temp2))  # size = 3x4
42 |     m = +det(mat1[:, 1:])
43 |     n = -det(np.delete(mat1, 1, axis=1))
44 |     p = +det(np.delete(mat1, 2, axis=1))
45 |     q = -det(temp1)
46 |     temp3 = np.array([p1 @ p1, p2 @ p2, p3 @ p3]).reshape(3, 1)
47 |     temp4 = np.hstack((temp3, mat1))
48 |     temp5 = np.array([2 * q, -m, -n, -p, 0])
49 |     mat2 = np.vstack((temp4, temp5))  # size = 4x5
50 |     A = +det(mat2[:, 1:])
51 |     B = -det(np.delete(mat2, 1, axis=1))
52 |     C = +det(np.delete(mat2, 2, axis=1))
53 |     D = -det(np.delete(mat2, 3, axis=1))
54 |     E = +det(mat2[:, :-1])
55 |     pc = -np.array([B, C, D]) / 2 / A
56 |     r = np.sqrt(B * B + C * C + D * D - 4 * A * E) / 2 / abs(A)
57 |     return pc, r
58 | 
59 | 
60 | def calculate_distance(point_1, point_2):
61 |     dis_x = abs(point_1[0] - point_2[0])
62 |     dis_y = abs(point_1[1] - point_2[1])
63 |     distance = (dis_x ** 2 + dis_y ** 2) ** 0.5
64 |     return distance
65 | 
66 | 
67 | def determine_point_in_circle(p, pc, r):
68 |     distance = calculate_distance(p, pc)
69 |     if distance < r:
70 |         return True
71 |     else:
72 |         return False
73 | 
74 | 
75 | def find_left_and_right_point(point_1, point_2):
76 |     if point_1[0] < point_2[0]:
77 |         return point_1, point_2
78 |     elif point_1[0] > point_2[0]:
79 |         return point_2, point_1
80 | 
81 | 
82 | # find which of the four keypoints lies inside the circle through the other three
83 | def find_center_point(keypoints):
84 |     assert len(keypoints) == 4
85 |     pc_0, r_0 = points2circle(keypoints[1], keypoints[2], keypoints[3])
86 |     pc_1, r_1 = points2circle(keypoints[0], keypoints[2], keypoints[3])
87 |     pc_2, r_2 = points2circle(keypoints[0], keypoints[1], keypoints[3])
88 |     pc_3, r_3 = points2circle(keypoints[0], keypoints[1], keypoints[2])
89 |     if determine_point_in_circle(keypoints[0], pc_0, r_0):
90 |         return 0
91 |     elif determine_point_in_circle(keypoints[1], pc_1, r_1):
92 |         return 1
93 |     elif determine_point_in_circle(keypoints[2], pc_2, r_2):
94 |         return 2
95 |     elif determine_point_in_circle(keypoints[3], pc_3, r_3):
96 |         return 3
97 | 
98 | 
99 | def get_two_point_center_point(p_1, p_2):
100 |     p = []
101 |     p.append((p_1[0] + p_2[0]) / 2)
102 |     p.append((p_1[1] + p_2[1]) / 2)
103 |     return p
104 | 
105 | 
106 | def find_two_point_by_distance(c_point, center_id, keypoints):
107 |     keypoints.pop(center_id)
108 |     assert len(keypoints) == 3
109 |     id_list = [[0, 1], [0, 2], [1, 2]]
110 |     center_0_1 = get_two_point_center_point(keypoints[0], keypoints[1])
111 |     center_0_2 = get_two_point_center_point(keypoints[0], keypoints[2])
112 |     center_1_2 = get_two_point_center_point(keypoints[1], keypoints[2])
113 |     distance_c01_cp = calculate_distance(center_0_1, c_point)
114 |     distance_c02_cp = calculate_distance(center_0_2, c_point)
115 |     distance_c12_cp = calculate_distance(center_1_2, c_point)
116 |     distance_list = [distance_c01_cp, distance_c02_cp, distance_c12_cp]
117 |     sorted_id = sorted(range(len(distance_list)), key=lambda k: distance_list[k], reverse=True)
118 |     near_id = id_list[sorted_id[-1]]
119 |     return keypoints[near_id[0]], keypoints[near_id[1]]
120 | 
121 | 
122 | def find_two_points_by_angle(c_point, center_id, keypoints):
123 |     keypoints.pop(center_id)
124 |     assert len(keypoints) == 3
125 |     id_list = [[0, 1], [0, 2], [1, 2]]
126 |     angle_0_1 = angle(c_point + keypoints[0], c_point + keypoints[1])
127 |     angle_0_2 = angle(c_point + keypoints[0], c_point + keypoints[2])
128 |     angle_1_2 = angle(c_point + keypoints[1], c_point + keypoints[2])
129 |     angle_list = [angle_0_1, angle_0_2, angle_1_2]
130 |     sorted_id = sorted(range(len(angle_list)), key=lambda k: angle_list[k], reverse=True)
131 |     near_id = id_list[sorted_id[0]]
132 |     return keypoints[near_id[0]], keypoints[near_id[1]]
133 | 
134 | 
135 | def determine_direction(point_center, point):
136 |     center_x = point_center[0]
137 |     center_y = point_center[1]
138 |     x = point[0]
139 |     y = point[1]
140 |     if center_x < x and center_y > y:
141 |         return 'northeast'
142 |     elif center_x < x and center_y < y:
143 |         return 'southeast'
144 |     elif center_x > x and center_y < y:
145 |         return 'southwest'
146 |     elif center_x > x and center_y > y:
147 |         return 'northwest'
148 | 
149 | 
150 | # included angle in degrees between two segments, each given as [x1, y1, x2, y2]
151 | def angle(v1, v2):
152 |     dx1 = v1[2] - v1[0]
153 |     dy1 = v1[3] - v1[1]
154 |     dx2 = v2[2] - v2[0]
155 |     dy2 = v2[3] - v2[1]
156 |     angle1 = math.atan2(dy1, dx1)
157 |     angle1 = float(angle1 * 180 / math.pi)
158 |     # print(angle1)
159 |     angle2 = math.atan2(dy2, dx2)
160 |     angle2 = float(angle2 * 180 / math.pi)
161 |     # print(angle2)
162 |     if angle1 * angle2 >= 0:
163 |         included_angle = abs(angle1 - angle2)
164 |     else:
165 |         included_angle = abs(angle1) + abs(angle2)
166 |         if included_angle > 180:
167 |             included_angle = 360 - included_angle
168 |     return included_angle
169 | 
170 | 
171 | def calculate_rotation_angle(keypoints):
172 |     # decide which of the four points lies in the middle of the other three
173 |     center_id = find_center_point(keypoints)
174 |     point_center = keypoints[center_id]
175 | 
176 |     # find the two points on the edge of the drum
177 |     # the_one, the_other = find_two_point_by_distance(point_center, center_id, keypoints)
178 |     the_one, the_other = find_two_points_by_angle(point_center, center_id, keypoints)
179 |     point_left, point_right = find_left_and_right_point(the_one, the_other)
180 | 
181 |     # decide which way the front of the drum currently faces
182 |     point_mid = []
183 |     point_mid.append((point_left[0] + point_right[0]) / 2)
184 |     point_mid.append((point_left[1] + point_right[1]) / 2)
185 |     direction = determine_direction(point_center, point_mid)
186 |     if direction == 'northeast' or direction == 'northwest':
187 |         turn = 'clockwise'
188 |     elif direction == 'southeast' or direction == 'southwest':
189 |         turn = 'anticlockwise'
190 |     # print(point_left, point_right, sep=' ')
191 | 
192 |     # determine the included angle; v1 is the y-axis vector
193 |     v1 = [0.0, 0.0, 0.0, -1.0]
194 |     # v2 = point_left + point_right
195 |     if turn == 'clockwise':
196 |         v2 = point_right + point_left
197 |     elif turn == 'anticlockwise':
198 |         v2 = point_left + point_right
199 |     the_angle = angle(v1, v2)
200 |     return the_angle, turn
201 | 
202 | 
203 | if __name__ == '__main__':
204 |     the_id = '2'
205 |     re_dir = 'json_' + the_id
206 |     files = os.listdir(re_dir)
207 |     flag = 1
208 |     for file in files:
209 |         re_json_dir = os.path.join(re_dir, file)
210 |         with open(re_json_dir, 'r', encoding='utf8') as fp:
211 |             json_data = json.load(fp)
212 |         # print(json_data)
213 |         image_name = json_data['imagePath']
214 |         point_shapes = json_data['shapes']
215 |         key_points = []
216 |         for point in point_shapes:
217 |             key_points.append(point['points'][0])
218 |         print(image_name)
219 |         print(key_points)
220 |         turn_angle, turn_direction = calculate_rotation_angle(key_points)
221 |         print(turn_angle, turn_direction, sep=' ')
222 |         print('-----------------------------\n')
223 |         # write_line = image_name + ',' + str(turn_angle) + ',' + turn_direction
224 |         # f = open('image_angle.txt', 'a')
225 |         # if flag == 1:
226 |         #     f.write(str(write_line))
227 |         #     flag = 0
228 |         # f.write('\n' + str(write_line))
229 |         # f.close()
--------------------------------------------------------------------------------
/offset/main/models.py:
--------------------------------------------------------------------------------
1 | from torchsummaryX import summary
2 | from net_util import *
3 |
4 |
5 | # U-net downsampling block: two convolutions
6 | class DoubleConv(nn.Module):
7 |
8 |     def __init__(self, in_channels, out_channels, channel_reduce=False):  # only defines the layers the network needs
9 | super(DoubleConv, self).__init__()
10 |
11 |         # channel-reduction coefficient
12 | coefficient = 2 if channel_reduce else 1
13 |
14 | self.down = nn.Sequential(
15 | nn.Conv2d(in_channels, coefficient * out_channels, kernel_size=(3, 3), padding=1),
16 | # nn.BatchNorm2d(coefficient * out_channels),
17 | nn.ReLU(inplace=True),
18 | nn.Conv2d(coefficient * out_channels, out_channels, kernel_size=(3, 3), padding=1),
19 | # nn.BatchNorm2d(out_channels),
20 | nn.ReLU(inplace=True)
21 | )
22 |
23 | def forward(self, x):
24 | return self.down(x)
25 |
26 |
27 | # upsampling block (transposed convolution plus skip connection)
28 | class Up(nn.Module):
29 |
30 |     # note the inputs: in_channels is the channel count fed into DoubleConv, out_channels is the channel count after it
31 | def __init__(self, in_channels, out_channels):
32 | super().__init__()
33 |         # first upsample the feature map
34 | self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=4, stride=2, padding=1)
35 | self.conv = DoubleConv(in_channels, out_channels, channel_reduce=True)
36 |
37 | def forward(self, x1, x2):
38 | x1 = self.up(x1)
39 | x = torch.cat([x1, x2], dim=1)
40 | x = self.conv(x)
41 | return x
42 |
43 |
44 | # simple U-net model (only the encoder is used in forward)
45 | class U_net(nn.Module):
46 |
47 |     def __init__(self):  # only defines the layers the network needs
48 | super(U_net, self).__init__()
49 |
50 |         # downsampling path
51 | self.double_conv1 = DoubleConv(3, 32)
52 | self.double_conv2 = DoubleConv(32, 64)
53 | self.double_conv3 = DoubleConv(64, 128)
54 | self.double_conv4 = DoubleConv(128, 256)
55 | self.double_conv5 = DoubleConv(256, 256)
56 |
57 |         # regress the coordinates directly instead of upsampling
58 | self.conv1 = nn.Conv2d(256, 64, kernel_size=(1, 1), padding=0)
59 | self.conv2 = nn.Conv2d(64, 16, kernel_size=(1, 1), padding=0)
60 | self.fc1 = nn.Linear(11264, 128)
61 | self.fc2 = nn.Linear(128, 8)
62 |
63 |         # upsampling path (defined but unused in forward)
64 | self.up1 = Up(512, 128)
65 | self.up2 = Up(256, 64)
66 | self.up3 = Up(128, 32)
67 | self.up4 = Up(64, 16)
68 |
69 |         # final layer
70 | # self.conv = nn.Conv2d(16, 1, kernel_size=(1, 1), padding=0)
71 | # self.fc1 = nn.Linear(180224, 1024)
72 | # self.fc2 = nn.Linear(1024, 8)
73 |
74 | def forward(self, x):
75 | # down
76 | # print(x.shape)
77 | c1 = self.double_conv1(x) # (,32,512,512)
78 | p1 = nn.MaxPool2d(2)(c1) # (,32,256,256)
79 | c2 = self.double_conv2(p1) # (,64,256,256)
80 | p2 = nn.MaxPool2d(2)(c2) # (,64,128,128)
81 | c3 = self.double_conv3(p2) # (,128,128,128)
82 | p3 = nn.MaxPool2d(2)(c3) # (,128,64,64)
83 | c4 = self.double_conv4(p3) # (,256,64,64)
84 | p4 = nn.MaxPool2d(2)(c4) # (,256,32,32)
85 | c5 = self.double_conv5(p4) # (,256,32,32)
86 |         # no pooling after the last convolution
87 |
88 | # up
89 | # u1 = self.up1(c5, c4) # (,128,64,64)
90 | # u2 = self.up2(u1, c3) # (,64,128,128)
91 | # u3 = self.up3(u2, c2) # (,32,256,256)
92 | # u4 = self.up4(u3, c1) # (,16,512,512)
93 |
94 |         # head: 1x1 convolutions followed by fully connected layers, outputting 8 coordinates
95 | x1 = self.conv1(c5)
96 | x2 = self.conv2(x1)
97 | # print(x1.shape)
98 | x2 = x2.view(x2.size(0), -1)
99 |
100 | # print(x1.shape)
101 | x = self.fc1(x2)
102 | out = self.fc2(x)
103 |
104 | return out
105 |
106 | def summary(self, net):
107 | x = torch.rand(cfg['batch_size'], 3, 352, 512) # 352*512
108 |         # move to the device
109 | x = x.to(cfg['device'])
110 |         # shape of the output y
111 | # print(net(x).shape)
112 |
113 |         # print the network structure
114 | summary(net, x)
115 |
116 |
117 | # debugging entry point
118 | if __name__ == "__main__":
119 | m = U_net().to(cfg['device'])
120 | m.summary(m)
121 |
--------------------------------------------------------------------------------
/offset/main/net_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import numpy as np
4 | from torch import nn
5 | import torchvision
6 | from config import config as cfg
7 | import torch.utils.data
8 | from torchvision import datasets, transforms, models
9 | import cv2
10 | from data_pre import one_json_to_numpy
11 |
12 |
13 | # Dataset wrapper for the keypoint data
14 | class Dataset(torch.utils.data.Dataset):
15 |     # initialisation
16 | def __init__(self, dataset_path):
17 | self.dataset_path = dataset_path
18 | self.img_name_list = os.listdir(os.path.join(dataset_path, 'imgs'))
19 |
20 |     # return the image and label for the given index
21 | def __getitem__(self, index):
22 |         # process the image first
23 | img_path = os.path.join(self.dataset_path, 'imgs', self.img_name_list[index])
24 | img = cv2.imread(img_path)
25 | img = cv2.resize(img, (512, 352))
26 | img = transforms.ToTensor()(img)
27 |
28 |         # read the label
29 | label_path = os.path.join(self.dataset_path, 'labels', self.img_name_list[index].split('.')[0]+'.json')
30 | mask = one_json_to_numpy(label_path)
31 | # mask = np.load(os.path.join(self.dataset_path, 'masks', self.img_name_list[index].split('.')[0] + '.npy'))
32 | mask = torch.tensor(mask, dtype=torch.float32)
33 |
34 | # print(img_path)
35 | # print(label_path)
36 | # print('-----------------')
37 |         if os.path.basename(img_path).split('.')[0] != os.path.basename(label_path).split('.')[0]:
38 |             print('image and label file names do not match')
39 |
40 | return img, mask
41 |
42 |     # size of the dataset
43 | def __len__(self):
44 | return len(self.img_name_list)
45 |
--------------------------------------------------------------------------------
/offset/main/test_main.py:
--------------------------------------------------------------------------------
1 | from net_util import *
2 | import os
3 | from models import *
4 | import cv2
5 | import xml.etree.cElementTree as ET
6 | from xml.etree import ElementTree
7 | import numpy as np
8 | from xml.dom import minidom
9 | import time
10 | from determine_rotation_angle import calculate_rotation_angle
11 | from data_pre import one_json_to_numpy
12 |
13 |
14 | def show_point_on_picture(img, landmarks, landmarks_gt):
15 | for point in landmarks[0]:
16 | point = tuple([int(point[0]), int(point[1])])
17 | # print(point)
18 | img = cv2.circle(img, center=point, radius=20, color=(0, 0, 255), thickness=-1)
19 | for point in landmarks_gt:
20 | point = tuple([int(point[0]), int(point[1])])
21 | # print(point)
22 | img = cv2.circle(img, center=point, radius=20, color=(0, 255, 0), thickness=-1)
23 | return img
24 |
25 |
26 | # prediction
27 | def evaluate():
28 | date = cfg['test_date']
29 | way = cfg['test_way']
30 |
31 |     # test image folder
32 | img_path = os.path.join('..', 'data', date, way, 'imgs')
33 |     # test label folder
34 | label_path = os.path.join('..', 'data', date, way, 'labels')
35 |
36 |     # define the model
37 | model = U_net()
38 |
39 | # 034, 128
40 | model.load_state_dict(torch.load(os.path.join('..', 'weights', cfg['pkl_file'])))
41 | model.to(cfg['device'])
42 | # model.summary(model)
43 | model.eval()
44 |
45 |     # downsampling module
46 | # resize = torch.nn.Upsample(scale_factor=(1, 0.5), mode='bilinear', align_corners=True)
47 |
48 | diff_angle_list = []
49 | total_loss = 0
50 |     # start prediction
51 | for index, name in enumerate(os.listdir(img_path)):
52 | print(name+": "+str(index+1))
53 |
54 | img = cv2.imread(os.path.join(img_path, name))
55 | img = cv2.resize(img, (512, 352))
56 | img = transforms.ToTensor()(img)
57 |         img = torch.unsqueeze(img, dim=0)  # during training DataLoader adds the batch dimension automatically
58 | print(img.shape)
59 |
60 |         # feed into the network
61 | img = img.to(cfg['device'])
62 |
63 | pre = model(img)
64 | pre = pre.cpu().detach().numpy()
65 | pre = pre.reshape(pre.shape[0], -1, 2)
66 |
67 | gt_point = one_json_to_numpy(os.path.join(label_path, name.split('.')[0] + '.json'))
68 | gt_point = gt_point.reshape(-1, 2)
69 |
70 | print(os.path.join(img_path, name))
71 | print(os.path.join(label_path, name.split('.')[0] + '.json'))
72 |
73 | print(pre)
74 | print(gt_point)
75 |
76 | pre_label = torch.Tensor(pre.reshape(1, -1)).to(cfg['device'])
77 | label = torch.Tensor(gt_point.reshape(1, -1)).to(cfg['device'])
78 |
79 | loss_F = torch.nn.MSELoss()
80 | loss_F.to(cfg['device'])
81 |         loss = loss_F(pre_label, label)  # compute the loss
82 |
83 |         print('coordinate error loss: ', loss.item())
84 | total_loss += loss.item()
85 | print('---------')
86 |
87 | # del img
88 | #
89 | # img = cv2.imread(os.path.join(img_path, name))
90 | # print(img.shape)
91 | # img = show_point_on_picture(img, pre, gt_point)
92 | #
93 |         # # save the rendered image
94 | # save_dir = os.path.join('..', 'result', date, way + '_data')
95 | # if not os.path.exists(save_dir):
96 | # os.makedirs(save_dir)
97 | #
98 | # result_dir = os.path.join(save_dir, name.split('.')[0] + '_keypoint.jpg')
99 | # print(result_dir)
100 | # cv2.imwrite(result_dir, img)
101 | #
102 |         # # compute the rotation angle from the predicted keypoints
103 | # keypoints = pre[0].tolist()
104 | # pre_angle, pre_turn = calculate_rotation_angle(keypoints)
105 | # print(pre_angle, pre_turn, sep=' ')
106 | #
107 |         # # compute the rotation angle from the ground-truth keypoints
108 | # keypoints = gt_point.tolist()
109 | # true_angle, true_turn = calculate_rotation_angle(keypoints)
110 | # print(true_angle, true_turn, sep=' ')
111 | #
112 | # if pre_angle < 0:
113 | # diff_angle = -100000
114 | # else:
115 | # diff_angle = abs(true_angle - pre_angle)
116 | # print(diff_angle)
117 | #
118 | # diff_angle_list.append(diff_angle)
119 | # print('=========')
120 | #
121 | # print(np.mean(diff_angle_list))
122 |
123 | print(total_loss / (index + 1))
124 |
125 |
126 | if __name__ == "__main__":
127 | # choose weights
128 | # choose_weights()
129 |
130 |     # run prediction with one set of weights
131 | evaluate()
132 |
--------------------------------------------------------------------------------
/offset/main/train_main.py:
--------------------------------------------------------------------------------
1 | from models import *
2 | from config import config as cfg
3 |
4 |
5 | # training entry point
6 | def train():
7 | print('start')
8 |
9 | save_epoch = cfg['save_epoch']
10 | min_avg_loss = 5000
11 | date = cfg['train_date']
12 |     # model
13 |
14 | model = U_net()
15 |     # load weights
16 | start_epoch = 0
17 | # model.load_state_dict(torch.load(os.path.join('..', 'weights', 'epoch_001.pkl')))
18 | model.to(cfg['device'])
19 | model.summary(model)
20 |
21 | start_epoch += 1
22 |
23 |     # dataset
24 | dataset = Dataset(os.path.join('..', 'data', date, 'train'))
25 |
26 | train_data_loader = torch.utils.data.DataLoader(dataset=dataset,
27 | batch_size=cfg['batch_size'],
28 | shuffle=True)
29 |
30 |     # loss function and optimizer
31 | loss_F = torch.nn.MSELoss()
32 | loss_F.to(cfg['device'])
33 | optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
34 |
35 | for epoch in range(start_epoch, cfg['epochs'], 1):
36 | # model.train()
37 | total_loss = 0.0
38 | min_loss = 10000
39 | max_loss = 10
40 |         # iterate over the batches
41 | for index, (x, y) in enumerate(train_data_loader):
42 | img = x.to(cfg['device'])
43 | label = y.to(cfg['device'])
44 |
45 | # print('------------')
46 | # print(img.shape)
47 | # print('------------')
48 |
49 |             pre = model(img)  # forward pass
50 |             # compute the loss and backpropagate
51 | # print(pre.shape)
52 | # print('------------')
53 | # print(label.shape)
54 |
55 |             loss = loss_F(pre, label)  # compute the loss
56 |             optimizer.zero_grad()  # gradients must be zeroed before each backward pass
57 |             loss.backward()  # backpropagate to populate the gradients
58 |             optimizer.step()  # update the parameters
59 | total_loss += loss.item()
60 |
61 |             if loss.item() < min_loss:
62 |                 min_loss = loss.item()
63 | 
64 |             if loss.item() > max_loss:
65 |                 max_loss = loss.item()
66 |
67 | # if (index+1) % 5 == 0:
68 | # print('Epoch %d loss %f' % (epoch, total_loss / (index + 1)))
69 |
70 | avg_loss = total_loss/(index+1)
71 |
72 | if avg_loss < min_avg_loss:
73 | print(pre)
74 | print(label)
75 | print('-------------------')
76 | min_avg_loss = avg_loss
77 | torch.save(model.state_dict(), os.path.join('..', "weights", 'min_loss.pth'))
78 |
79 | print('Epoch %d, photo number %d, avg loss %f, min loss %f, max loss %f, min avg loss %f' % (epoch, index+1, avg_loss, min_loss, max_loss, min_avg_loss))
80 |
81 | print('========================')
82 |
83 |         # save the weights every save_epoch epochs
84 | save_name = "epoch_" + str(epoch).zfill(3) + ".pth"
85 | if (epoch) % save_epoch == 0:
86 | torch.save(model.state_dict(), os.path.join('..', "weights", save_name))
87 |
88 |
89 | if __name__ == "__main__":
90 |     # train
91 | train()
92 |
--------------------------------------------------------------------------------
/offset/result/08_11/test_data/202107170976_4_keypoint.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ExileSaber/KeyPoint-Detection/3fbe892aaed83b305080364d3da1664abb64c53b/offset/result/08_11/test_data/202107170976_4_keypoint.jpg
--------------------------------------------------------------------------------