├── .gitignore
├── README.md
├── data
│   └── test.yaml
├── deploy
│   ├── app.py
│   └── test_app.py
├── detect.py
├── example_dataset
│   └── match
│       ├── images
│       │   └── 1_10_12
│       │       ├── 0.jpg
│       │       ├── 1.jpg
│       │       ├── 2.jpg
│       │       ├── 3.jpg
│       │       ├── 4.jpg
│       │       └── 5.jpg
│       ├── labels
│       │   └── 1_10_12.csv
│       └── videos
│           └── 1_10_12.mp4
├── models
│   └── tracknet.py
├── ncnn_inference
│   ├── Makefile
│   ├── convert_pnnx.py
│   └── tracknet.cpp
├── requirements.txt
├── runs
│   └── .gitkeep
├── tf2torch
│   ├── diff.txt
│   ├── onnx2pt.py
│   └── track.pt
├── tools
│   ├── Frame_Generator.py
│   ├── Frame_Generator_batch.py
│   ├── Frame_Generator_rally.py
│   ├── check_labels.py
│   ├── handle_Darklabel.py
│   ├── handle_tracknet_dataset.py
│   └── label_tool.py
├── train.py
├── utils
│   ├── augmentations.py
│   ├── dataloaders.py
│   └── general.py
└── val.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | __pycache__
3 | runs
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TrackNetV2-pytorch
2 |
3 | Paper: **TrackNetV2: Efficient Shuttlecock Tracking Network**
4 |
5 | Original project (TensorFlow): https://nol.cs.nctu.edu.tw:234/open-source/TrackNetv2
6 |
7 | > <del>The labeling tool and dataset uploaded by the original authors are no longer available.</del>
8 | >
9 | > The author has since re-uploaded the dataset.
10 |
11 | Paper notes: [TrackNetV2 paper notes and PyTorch reimplementation (Chinese)](https://zhuanlan.zhihu.com/p/624900770)
12 |
13 |
14 |
15 | ## Inference with PyTorch weights converted from TensorFlow weights:
16 |
17 | ```shell
18 | git apply tf2torch/diff.txt    # patch Conv's BatchNorm to match the TF weights' BN axis (see notes in tf2torch/onnx2pt.py)
19 | python detect.py --source xxx.mp4 --weights ./tf2torch/track.pt --view-img # TrackNetv2/3_in_3_out/model906_30
20 | ```
21 |
22 |
23 |
24 | ## Inference:
25 |
26 | ```shell
27 | python detect.py --source xxx.mp4 --weights xxx.pt --view-img
28 | ```
29 |
30 |
31 |
32 | ## Training:
33 |
34 | ```shell
35 | # training from scratch
36 | python train.py --data data/match.yaml
37 |
38 | # training from pretrain weight
39 | python train.py --weights xxx.pt --data data/match.yaml
40 |
41 | # resume training
42 | python train.py --data data/match.yaml --resume
43 | ```
44 |
45 |
46 |
47 | ## Evaluation:
48 |
49 | ```shell
50 | python val.py --weights xxx.pt --data data/match.yaml
51 | ```
52 |
53 |
54 |
55 | ## Deployment:
56 |
57 | ```shell
58 | # Server
59 | python deploy/app.py --weights xxx.pt
60 |
61 | # Client
62 | python deploy/test_app.py
63 | ```
64 |
65 |
66 |
67 |
68 |
69 | ## Dataset Preparation:
70 |
71 | ```
72 | # TrackNetV2 dataset
73 | # /home/chg/Badminton/TrackNetV2
74 | # - Amateur
75 | # - Professional
76 | # - Test
77 |
78 | python tools/handle_tracknet_dataset.py /home/chg/Badminton/TrackNetV2/Amateur
79 | python tools/handle_tracknet_dataset.py /home/chg/Badminton/TrackNetV2/Professional
80 | python tools/handle_tracknet_dataset.py /home/chg/Badminton/TrackNetV2/Test
81 |
82 | python tools/Frame_Generator_rally.py /home/chg/Badminton/TrackNetV2/Amateur
83 | python tools/Frame_Generator_rally.py /home/chg/Badminton/TrackNetV2/Professional
84 | python tools/Frame_Generator_rally.py /home/chg/Badminton/TrackNetV2/Test
85 |
86 |
87 | # TrackNetV2 dataset config : data/match.yaml
88 | path: /home/chg/Badminton/TrackNetV2
89 | train:
90 | - Amateur
91 | - Professional
92 | val:
93 | - Test
94 |
95 | # you can also use the following config for testing
96 | train:
97 | - Test/match1/images/1_05_02
98 | val:
99 | - Test/match2/images/1_03_03
100 |
101 | # or
102 | train:
103 | - Test/match1
104 | val:
105 | - Test/match2
106 |
107 | ```
108 |
109 |
110 |
111 | ## Reference:
112 |
113 | https://github.com/mareksubocz/TrackNet
114 |
115 | https://nol.cs.nctu.edu.tw:234/open-source/TrackNetv2
116 |
117 | https://github.com/ultralytics/yolov5
--------------------------------------------------------------------------------
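After the Dataset Preparation steps above, each match directory follows the layout the dataloader expects, mirroring `example_dataset/match` (a sketch; rally names vary):

```
Test/match1
├── images
│   └── 1_05_02
│       ├── 0.jpg
│       ├── 1.jpg
│       └── ...
├── labels
│   └── 1_05_02.csv
└── videos
    └── 1_05_02.mp4
```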
/data/test.yaml:
--------------------------------------------------------------------------------
1 | # Dataset config in YOLOv5 🚀 (Ultralytics, GPL-3.0 license) format
2 | # TrackNetV2 example dataset
3 | # Example usage: python train.py --data data/test.yaml
4 | # TrackNetV2-pytorch
5 | # └── example_dataset
6 | #     └── match  ← example data here
7 |
8 |
9 |
10 | # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
11 | path: ./example_dataset/match # dataset root dir
12 | train:
13 | - images/1_10_12
14 |
15 | val:
16 | - images/1_10_12
17 |
18 | # Classes
19 | # nc: 3 # number of classes
20 | # names: ['mcar-0', 'mpeo-1', 'peo-2']
21 |
--------------------------------------------------------------------------------
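A minimal sketch of how this config is consumed (assuming PyYAML; the project's actual loader is `check_dataset` in `utils/general.py`, whose body is not shown in this excerpt):

```python
import yaml

# Resolve the dataset root and the train/val image directories.
with open('data/test.yaml') as f:
    cfg = yaml.safe_load(f)

print(cfg['path'])   # ./example_dataset/match
print(cfg['train'])  # ['images/1_10_12']
print(cfg['val'])    # ['images/1_10_12']
```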
/deploy/app.py:
--------------------------------------------------------------------------------
1 | import io
2 | import json
3 |
4 | import sys
5 | from pathlib import Path
6 | FILE = Path(__file__).resolve()
7 | ROOT = FILE.parents[1] # 1 -> root directory
8 | if str(ROOT) not in sys.path:
9 | sys.path.append(str(ROOT)) # add ROOT to PATH
10 |
11 | import cv2
12 | import tempfile
13 | import numpy as np
14 | from argparse import ArgumentParser
15 |
16 | import torch
17 | from torchvision import models
18 | import torchvision.transforms as transforms
19 | from PIL import Image
20 | from flask import Flask, jsonify, request
21 |
22 | from models.tracknet import TrackNet
23 | from utils.general import get_shuttle_position
24 |
25 | # reference: https://pytorch.org/tutorials/intermediate/flask_rest_api_tutorial.html
26 |
27 |
28 | def prediction(video, model, device, imgsz):
29 | vid_cap = cv2.VideoCapture(video)
30 | fps = vid_cap.get(cv2.CAP_PROP_FPS)
31 | w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
32 | h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
33 |
34 | fourcc = cv2.VideoWriter_fourcc(*'mp4v')
35 | out = cv2.VideoWriter("./predict.mp4", fourcc, fps, (w, h))
36 |
37 | count = 0
38 | video_end = False
39 | while vid_cap.isOpened():
40 | imgs = []
41 | for _ in range(3):
42 | ret, img = vid_cap.read()
43 |
44 | if not ret:
45 | video_end = True
46 | break
47 | imgs.append(img)
48 |
49 | if video_end:
50 | break
51 |
52 | imgs_torch = []
53 | for img in imgs:
54 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
55 |
56 | img_torch = transforms.ToTensor()(img).to(device) # already [0, 1]
57 | img_torch = transforms.functional.resize(img_torch, imgsz, antialias=True)
58 |
59 | imgs_torch.append(img_torch)
60 |
61 | imgs_torch = torch.cat(imgs_torch, dim=0).unsqueeze(0)
62 |
63 | preds = model(imgs_torch)
64 | preds = preds[0].detach().cpu().numpy()
65 |
66 | y_preds = preds > 0.5
67 | y_preds = y_preds.astype('float32')
68 | y_preds = y_preds*255
69 | y_preds = y_preds.astype('uint8')
70 |
71 | for i in range(3):
72 | (visible, cx_pred, cy_pred) = get_shuttle_position(y_preds[i])
73 | (cx, cy) = (int(cx_pred*w/imgsz[1]), int(cy_pred*h/imgsz[0]))
74 | if visible:
75 | cv2.circle(imgs[i], (cx, cy), 8, (0,0,255), -1)
76 |
77 |
78 | out.write(imgs[i])
79 | print("{} ---- visible: {} cx: {} cy: {}".format(count, visible, cx, cy))
80 |
81 | count += 1
82 |
83 | out.release()
84 | vid_cap.release()
85 |
86 |
87 |
88 |
89 |
90 | def parse_opt():
91 | parser = ArgumentParser()
92 |
93 | parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[288, 512], help='image size h,w')
94 | parser.add_argument('--weights', type=str, default=ROOT / 'best.pt', help='Path to trained model weights.')
95 |
96 | opt = parser.parse_args()
97 |
98 | return opt
99 |
100 |
101 | def main(opt):
102 | f_weights = str(opt.weights)
103 | imgsz = opt.imgsz
104 |
105 | device = "cuda"
106 |
107 | model = TrackNet().to(device)
108 | model.load_state_dict(torch.load(f_weights))
109 | model.eval()
110 | print("initialize TrackNet, load weights: {}".format(f_weights))
111 |
112 | app = Flask(__name__)
113 |
114 | @app.route('/predict', methods=['POST'])
115 | def predict():
116 | if request.method == 'POST':
117 | file = request.files['file']
118 |
119 | # file.save('video.mp4')
120 | file_bytes = np.asarray(bytearray(file.read()), dtype=np.uint8)
121 |
122 | with tempfile.NamedTemporaryFile() as temp:
123 |                 temp.write(file_bytes)
124 |                 temp.flush()  # ensure all bytes are on disk before OpenCV reopens the file
125 |                 prediction(temp.name, model, device, imgsz)
126 |
127 | return 'Video processed successfully'
128 |
129 | app.run()
130 |
131 |
132 | if __name__ == '__main__':
133 | opt = parse_opt()
134 | main(opt)
--------------------------------------------------------------------------------
/deploy/test_app.py:
--------------------------------------------------------------------------------
1 | import requests
2 |
3 | data = {'file': open('/home/chg/Videos/ld/2012.mp4', 'rb')}
4 |
5 | resp = requests.post("http://localhost:5000/predict", files=data)
6 | print(resp.text)
--------------------------------------------------------------------------------
/detect.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 |
4 | import os
5 | import sys
6 | import cv2
7 | import numpy as np
8 | from pathlib import Path
9 | from argparse import ArgumentParser
10 |
11 | from models.tracknet import TrackNet
12 | from utils.general import get_shuttle_position
13 |
14 | # from yolov5 detect.py
15 | FILE = Path(__file__).resolve()
16 | ROOT = FILE.parents[0] # YOLOv5 root directory
17 | if str(ROOT) not in sys.path:
18 | sys.path.append(str(ROOT)) # add ROOT to PATH
19 | ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
20 |
21 |
22 | def parse_opt():
23 | parser = ArgumentParser()
24 |
25 | parser.add_argument('--source', type=str, default=ROOT / 'example_dataset/match/videos/1_10_12.mp4', help='Path to video.')
26 | parser.add_argument('--save-txt', action='store_true', help='save results to *.csv')
27 | parser.add_argument('--view-img', action='store_true', help='show results')
28 | parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[288, 512], help='image size h,w')
29 | parser.add_argument('--weights', type=str, default=ROOT / 'best.pt', help='Path to trained model weights.')
30 | parser.add_argument('--project', default=ROOT / 'runs/detect', help='save results to project/name')
31 |
32 | opt = parser.parse_args()
33 |
34 | return opt
35 |
36 |
37 | def main(opt):
38 | # imgsz = [288, 512]
39 | # imgsz = [360, 640]
40 |
41 | source_name = os.path.splitext(os.path.basename(opt.source))[0]
42 | b_save_txt = opt.save_txt
43 | b_view_img = opt.view_img
44 | d_save_dir = str(opt.project)
45 | f_weights = str(opt.weights)
46 | f_source = str(opt.source)
47 | imgsz = opt.imgsz
48 |
49 | # video_name ---> video_name_pred
50 | source_name = '{}_predict'.format(source_name)
51 |
52 | # runs/detect
53 | if not os.path.exists(d_save_dir):
54 | os.makedirs(d_save_dir)
55 |
56 | # runs/detect/video_name
57 | img_save_path = '{}/{}'.format(d_save_dir, source_name)
58 | if not os.path.exists(img_save_path):
59 | os.makedirs(img_save_path)
60 |
61 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
62 |
63 | model = TrackNet().to(device)
64 | model.load_state_dict(torch.load(f_weights))
65 | model.eval()
66 |
67 | # import ncnn
68 | # net = ncnn.Net()
69 | # net.load_param("./pt_30_optimize.ncnn.param")
70 | # net.load_model("./pt_30_optimize.ncnn.bin")
71 |
72 | vid_cap = cv2.VideoCapture(f_source)
73 | video_end = False
74 |
75 | video_len = vid_cap.get(cv2.CAP_PROP_FRAME_COUNT)
76 | fps = vid_cap.get(cv2.CAP_PROP_FPS)
77 | w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
78 | h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
79 |
80 | fourcc = cv2.VideoWriter_fourcc(*'mp4v')
81 | out = cv2.VideoWriter('{}/{}.mp4'.format(d_save_dir, source_name), fourcc, fps, (w, h))
82 |
83 | if b_save_txt:
84 | f_save_txt = open('{}/{}.csv'.format(d_save_dir, source_name), 'w')
85 | f_save_txt.write('frame_num,visible,x,y\n')
86 |
87 | if b_view_img:
88 | cv2.namedWindow(source_name, cv2.WINDOW_NORMAL)
89 | cv2.resizeWindow(source_name, (w, h))
90 |
91 | count = 0
92 | while vid_cap.isOpened():
93 | imgs = []
94 | for _ in range(3):
95 | ret, img = vid_cap.read()
96 |
97 | if not ret:
98 | video_end = True
99 | break
100 | imgs.append(img)
101 |
102 | if video_end:
103 | break
104 |
105 | imgs_torch = []
106 | for img in imgs:
107 | # https://www.geeksforgeeks.org/converting-an-image-to-a-torch-tensor-in-python/
108 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
109 |
110 | img_torch = torchvision.transforms.ToTensor()(img).to(device) # already [0, 1]
111 | img_torch = torchvision.transforms.functional.resize(img_torch, imgsz, antialias=True)
112 |
113 | imgs_torch.append(img_torch)
114 |
115 | imgs_torch = torch.cat(imgs_torch, dim=0).unsqueeze(0)
116 |
117 | preds = model(imgs_torch)
118 | preds = preds[0].detach().cpu().numpy()
119 |
120 | # ncnn
121 | # ex = net.create_extractor()
122 | # ex.input("in0", ncnn.Mat(imgs_torch.squeeze(0).numpy()).clone())
123 | # _, out0 = ex.extract("out0")
124 | # preds = np.array(out0)
125 |
126 | y_preds = preds > 0.5
127 | y_preds = y_preds.astype('float32')
128 | y_preds = y_preds*255
129 | y_preds = y_preds.astype('uint8')
130 |
131 | for i in range(3):
132 | (visible, cx_pred, cy_pred) = get_shuttle_position(y_preds[i])
133 | (cx, cy) = (int(cx_pred*w/imgsz[1]), int(cy_pred*h/imgsz[0]))
134 | if visible:
135 | cv2.circle(imgs[i], (cx, cy), 8, (0,0,255), -1)
136 |
137 | if b_save_txt:
138 | f_save_txt.write('{},{},{},{}\n'.format(count, visible, cx, cy))
139 |
140 | if b_view_img:
141 | cv2.imwrite('{}/{}.png'.format(img_save_path, count), imgs[i])
142 | cv2.imshow(source_name, imgs[i])
143 | cv2.waitKey(1)
144 |
145 | out.write(imgs[i])
146 | print("{} ---- visible: {} cx: {} cy: {}".format(count, visible, cx, cy))
147 |
148 | count += 1
149 |
150 | if cv2.waitKey(1) & 0xFF == ord('q'):
151 | break
152 |
153 | if b_save_txt:
154 |         # frames are consumed 3 at a time, so the last 1-2 frames may go unprocessed; pad them with zeros
155 | while count < video_len:
156 | f_save_txt.write('{},0,0,0\n'.format(count))
157 | count += 1
158 |
159 | f_save_txt.close()
160 |
161 | out.release()
162 | vid_cap.release()
163 | cv2.destroyAllWindows()
164 |
165 |
166 | if __name__ == '__main__':
167 | opt = parse_opt()
168 | main(opt)
169 |
--------------------------------------------------------------------------------
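Both `detect.py` and `deploy/app.py` import `get_shuttle_position` from `utils/general.py`, whose body is not shown in this excerpt. A minimal sketch consistent with the C++ port in `ncnn_inference/tracknet.cpp` (the center of the largest contour's bounding box on the thresholded heatmap); treat it as illustrative rather than the canonical implementation:

```python
import cv2

def get_shuttle_position(binary_pred):
    """binary_pred: uint8 heatmap thresholded to {0, 255}. Returns (visible, cx, cy)."""
    if cv2.countNonZero(binary_pred) == 0:
        return 0, 0, 0

    # OpenCV 4 returns (contours, hierarchy)
    cnts, _ = cv2.findContours(binary_pred, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    rects = [cv2.boundingRect(c) for c in cnts]
    x, y, w, h = max(rects, key=lambda r: r[2] * r[3])  # largest-area box wins
    return 1, x + w // 2, y + h // 2
```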
/example_dataset/match/images/1_10_12/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChgygLin/TrackNetV2-pytorch/185674ff3d97ef66f3f34ef111a705fd846dd402/example_dataset/match/images/1_10_12/0.jpg
--------------------------------------------------------------------------------
/example_dataset/match/images/1_10_12/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChgygLin/TrackNetV2-pytorch/185674ff3d97ef66f3f34ef111a705fd846dd402/example_dataset/match/images/1_10_12/1.jpg
--------------------------------------------------------------------------------
/example_dataset/match/images/1_10_12/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChgygLin/TrackNetV2-pytorch/185674ff3d97ef66f3f34ef111a705fd846dd402/example_dataset/match/images/1_10_12/2.jpg
--------------------------------------------------------------------------------
/example_dataset/match/images/1_10_12/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChgygLin/TrackNetV2-pytorch/185674ff3d97ef66f3f34ef111a705fd846dd402/example_dataset/match/images/1_10_12/3.jpg
--------------------------------------------------------------------------------
/example_dataset/match/images/1_10_12/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChgygLin/TrackNetV2-pytorch/185674ff3d97ef66f3f34ef111a705fd846dd402/example_dataset/match/images/1_10_12/4.jpg
--------------------------------------------------------------------------------
/example_dataset/match/images/1_10_12/5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChgygLin/TrackNetV2-pytorch/185674ff3d97ef66f3f34ef111a705fd846dd402/example_dataset/match/images/1_10_12/5.jpg
--------------------------------------------------------------------------------
/example_dataset/match/labels/1_10_12.csv:
--------------------------------------------------------------------------------
1 | frame_num,visible,x,y
2 | 0,1,0.6442708333333333,0.424074074074074
3 | 1,1,0.64375,0.4185185185185185
4 | 2,1,0.6427083333333333,0.4175925925925925
5 | 3,1,0.6421875,0.412962962962963
6 | 4,1,0.6416666666666667,0.412037037037037
7 | 5,1,0.6401041666666667,0.3944444444444444
8 |
--------------------------------------------------------------------------------
/example_dataset/match/videos/1_10_12.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChgygLin/TrackNetV2-pytorch/185674ff3d97ef66f3f34ef111a705fd846dd402/example_dataset/match/videos/1_10_12.mp4
--------------------------------------------------------------------------------
/models/tracknet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from torchsummary import summary
5 |
6 |
7 | class Conv(nn.Module):
8 | def __init__(self, ic, oc, k=(3,3), p="same", act=True):
9 | super().__init__()
10 | self.conv = nn.Conv2d(ic, oc, kernel_size=k, padding=p)
11 | self.bn = nn.BatchNorm2d(oc)
12 | self.act = nn.ReLU() if act else nn.Identity()
13 |
14 | # self.convs = nn.Sequential(
15 | # nn.Conv2d(ic, oc, kernel_size=k, padding=p),
16 | # nn.ReLU(),
17 | # nn.BatchNorm2d(oc),
18 | # )
19 |
20 | def forward(self, x):
21 |         return self.bn(self.act(self.conv(x)))     # conv -> relu -> bn; is this different from relu-bn-conv?
22 | #return self.convs(x)
23 |
24 |
25 | class TrackNet(nn.Module):
26 | def __init__(self):
27 | super().__init__()
28 |
29 | # VGG16
30 |         # self.conv2d_1 = Conv(3, 64)  # input: 3 grayscale frames
31 |         self.conv2d_1 = Conv(9, 64) # input: 3 RGB frames
32 | self.conv2d_2 = Conv(64, 64)
33 | self.max_pooling_1 = nn.MaxPool2d((2,2), stride=(2,2))
34 |
35 | self.conv2d_3 = Conv(64, 128)
36 | self.conv2d_4 = Conv(128, 128)
37 | self.max_pooling_2 = nn.MaxPool2d((2,2), stride=(2,2))
38 |
39 | self.conv2d_5 = Conv(128, 256)
40 | self.conv2d_6 = Conv(256, 256)
41 | self.conv2d_7 = Conv(256, 256)
42 | self.max_pooling_3 = nn.MaxPool2d((2,2), stride=(2,2))
43 |
44 | self.conv2d_8 = Conv(256, 512)
45 | self.conv2d_9 = Conv(512, 512)
46 | self.conv2d_10 = Conv(512, 512)
47 |
48 | # Deconv / UNet
49 | self.up_sampling_1 = nn.UpsamplingNearest2d(scale_factor=2)
50 | # concatenate_1 with conv2d_7, axis = 1
51 |
52 | self.conv2d_11 = Conv(768, 256)
53 | self.conv2d_12 = Conv(256, 256)
54 | self.conv2d_13 = Conv(256, 256)
55 |
56 | self.up_sampling_2 = nn.UpsamplingNearest2d(scale_factor=2)
57 | # concatenate_2 with conv2d_4, axis = 1
58 |
59 | self.conv2d_14 = Conv(384, 128)
60 | self.conv2d_15 = Conv(128, 128)
61 |
62 | self.up_sampling_3 = nn.UpsamplingNearest2d(scale_factor=2)
63 | # concatenate_3 with conv2d_2, axis = 1
64 |
65 | self.conv2d_16 = Conv(192, 64)
66 | self.conv2d_17 = Conv(64, 64)
67 |         self.conv2d_18 = nn.Conv2d(64, 3, kernel_size=(1,1), padding='same') # output: 3 heatmaps
68 |         # self.conv2d_18 = Conv(64, 1, k=(1,1))  # output: 1 heatmap
69 |
70 |
71 | def forward(self, x):
72 | # VGG16
73 | x = self.conv2d_1(x)
74 | x1 = self.conv2d_2(x)
75 | x = self.max_pooling_1(x1)
76 |
77 | x = self.conv2d_3(x)
78 | x2 = self.conv2d_4(x)
79 | x = self.max_pooling_2(x2)
80 |
81 | x = self.conv2d_5(x)
82 | x = self.conv2d_6(x)
83 | x3 = self.conv2d_7(x)
84 | x = self.max_pooling_3(x3)
85 |
86 | x = self.conv2d_8(x)
87 | x = self.conv2d_9(x)
88 | x = self.conv2d_10(x)
89 |
90 | # Deconv / UNet
91 | x = self.up_sampling_1(x)
92 | x = torch.concat([x, x3], dim=1)
93 |
94 | x = self.conv2d_11(x)
95 | x = self.conv2d_12(x)
96 | x = self.conv2d_13(x)
97 |
98 | x = self.up_sampling_2(x)
99 | x = torch.concat([x, x2], dim=1)
100 |
101 | x = self.conv2d_14(x)
102 | x = self.conv2d_15(x)
103 |
104 | x = self.up_sampling_3(x)
105 | x = torch.concat([x, x1], dim=1)
106 |
107 | x = self.conv2d_16(x)
108 | x = self.conv2d_17(x)
109 | x = self.conv2d_18(x)
110 |
111 | x = torch.sigmoid(x)
112 |
113 | return x
114 |
115 |
116 | # torch.load: load weights from disk
117 | # model.load_state_dict: copy loaded weights into the model
118 |
119 | # model.state_dict(): the model's weights
120 | # torch.save: save the model's weights
121 |
122 |
123 | if __name__ == '__main__':
124 | model = TrackNet()
125 | print(summary(model,(9, 288, 512), device="cpu"))
126 | #print(model)
127 |
--------------------------------------------------------------------------------
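A quick sanity check of the model's input/output contract, matching the `torch.rand(1, 9, 288, 512)` probe used in `ncnn_inference/convert_pnnx.py`:

```python
import torch
from models.tracknet import TrackNet

model = TrackNet().eval()
x = torch.rand(1, 9, 288, 512)   # 3 stacked RGB frames, values in [0, 1]
with torch.inference_mode():
    y = model(x)
print(y.shape)  # torch.Size([1, 3, 288, 512]) -- one sigmoid heatmap per input frame
```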
/ncnn_inference/Makefile:
--------------------------------------------------------------------------------
1 | all:
2 | g++ -o tracknet tracknet.cpp -lncnn -I../../ncnn-20240820-ubuntu-2204-shared/include/ncnn -L../../ncnn-20240820-ubuntu-2204-shared/lib/ `pkg-config --cflags --libs opencv4`
3 |
4 | #export LD_LIBRARY_PATH=../../ncnn-20240820-ubuntu-2204-shared/lib/
5 | #./tracknet ../smash_test.mp4
--------------------------------------------------------------------------------
/ncnn_inference/convert_pnnx.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pnnx
3 |
4 | from pathlib import Path
5 | import sys
6 |
7 | FILE = Path(__file__).resolve()
8 | ROOT = FILE.parents[1]
9 |
10 | if str(ROOT) not in sys.path:
11 | sys.path.append(str(ROOT)) # add ROOT to PATH
12 |
13 | from models.tracknet import TrackNet
14 |
15 |
16 | model = TrackNet().to("cpu")
17 | model.load_state_dict(torch.load("./last.pt"))
18 | model = model.eval()
19 |
20 | x = torch.rand(1, 9, 288, 512)
21 |
22 | opt_model = pnnx.export(model, "./last_opt.pt", x, fp16 = False)
23 |
--------------------------------------------------------------------------------
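The exported `last_opt.ncnn.param` / `last_opt.ncnn.bin` pair is what `tracknet.cpp` loads. It can also be exercised from Python first; this sketch mirrors the commented-out ncnn branch in `detect.py` (the `in0`/`out0` blob names come from the pnnx export):

```python
import ncnn
import numpy as np

net = ncnn.Net()
net.load_param("./last_opt.ncnn.param")
net.load_model("./last_opt.ncnn.bin")

x = np.random.rand(9, 288, 512).astype(np.float32)  # 3 stacked RGB frames, CHW

ex = net.create_extractor()
ex.input("in0", ncnn.Mat(x).clone())
_, out0 = ex.extract("out0")
preds = np.array(out0)  # (3, 288, 512) heatmaps, post-processed as in detect.py
```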
/ncnn_inference/tracknet.cpp:
--------------------------------------------------------------------------------
1 | // Tencent is pleased to support the open source community by making ncnn available.
2 | //
3 | // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4 | //
5 | // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 | // in compliance with the License. You may obtain a copy of the License at
7 | //
8 | // https://opensource.org/licenses/BSD-3-Clause
9 | //
10 | // Unless required by applicable law or agreed to in writing, software distributed
11 | // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 | // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 | // specific language governing permissions and limitations under the License.
14 |
15 | #include "net.h"
16 |
17 | #include <iostream>
18 | #include <chrono>
19 | #include <tuple>
20 | // #include <cassert>
21 |
22 | #include <vector>
23 | #include <opencv2/core/core.hpp>
24 | #include <opencv2/imgproc/imgproc.hpp>
25 | #include <opencv2/videoio/videoio.hpp>
26 |
27 | cv::Mat hwc2chw(const cv::Mat& src_mat9)
28 | {
29 |     std::vector<cv::Mat> bgr_channels(9);
30 | cv::split(src_mat9, bgr_channels);
31 | for (size_t i = 0; i < bgr_channels.size(); i++)
32 | {
33 |         bgr_channels[i] = bgr_channels[i].reshape(1, 1); // reshape to 1 channel, 1 row, n columns
34 | }
35 | cv::Mat dst_mat;
36 | cv::hconcat(bgr_channels, dst_mat);
37 | return dst_mat;
38 | }
39 |
40 |
41 | void print_shape(const ncnn::Mat &in)
42 | {
43 | std::cout << "d: " << in.d << " c: " << in.c << " w: " << in.w << " h: " << in.h << " cstep: " << in.cstep << std::endl;
44 | }
45 |
46 | std::tuple<int, int, int> get_shuttle_position(const cv::Mat binary_pred)
47 | {
48 | if (cv::countNonZero(binary_pred) <= 0)
49 | {
50 | // (visible, cx, cy)
51 | return std::make_tuple(0, 0, 0);
52 | }
53 | else
54 | {
55 |         std::vector<std::vector<cv::Point>> cnts;
56 |         cv::findContours(binary_pred, cnts, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE);
57 |         assert(cnts.size() != 0);
58 |
59 |         std::vector<cv::Rect> rects;
60 | for (const auto& ctr : cnts) {
61 | rects.push_back(cv::boundingRect(ctr));
62 | }
63 |
64 | int max_area_idx = 0;
65 | int max_area = rects[max_area_idx].width * rects[max_area_idx].height;
66 |
67 | for (size_t ii = 0; ii < rects.size(); ++ii) {
68 | int area = rects[ii].width * rects[ii].height;
69 | if (area > max_area) {
70 | max_area_idx = ii;
71 | max_area = area;
72 | }
73 | }
74 |
75 | cv::Rect target = rects[max_area_idx];
76 | int cx = target.x + target.width / 2;
77 | int cy = target.y + target.height / 2;
78 |
79 | // (visible, cx, cy)
80 | return std::make_tuple(1, cx, cy);
81 | }
82 | }
83 |
84 |
85 | static int detect_tracknet(const char* video_path)
86 | {
87 | ncnn::Net tracknet;
88 |
89 | // GPU
90 | tracknet.opt.use_vulkan_compute = false;
91 |
92 | if(tracknet.load_param("./last_opt.ncnn.param"))
93 | exit(-1);
94 | if(tracknet.load_model("./last_opt.ncnn.bin"))
95 | exit(-1);
96 |
97 |
98 | cv::VideoCapture vid_cap(video_path);
99 | bool video_end = false;
100 |
101 | int video_len = vid_cap.get(cv::CAP_PROP_FRAME_COUNT);
102 | double fps = vid_cap.get(cv::CAP_PROP_FPS);
103 | int w = vid_cap.get(cv::CAP_PROP_FRAME_WIDTH);
104 | int h = vid_cap.get(cv::CAP_PROP_FRAME_HEIGHT);
105 |
106 | int iw = 512;
107 | int ih = 288;
108 | int ic = 9;
109 |
110 | int count = 0;
111 | while (vid_cap.isOpened())
112 | {
113 |         std::vector<cv::Mat> imgs;
114 | for (int i = 0; i < 3; ++i)
115 | {
116 | cv::Mat img;
117 | bool ret = vid_cap.read(img);
118 | if (!ret)
119 | {
120 | video_end = true;
121 | break;
122 | }
123 |
124 | imgs.push_back(img);
125 | }
126 |
127 | if (video_end)
128 | break;
129 |
130 |
131 | auto start = std::chrono::high_resolution_clock::now();
132 |         std::vector<cv::Mat> imgs_hwc;
133 | for (int i=0; i<3; i++)
134 | {
135 | cv::Mat img;
136 | imgs[i].convertTo(img, CV_32F, 1.0 / 255.0);
137 |
138 | cv::resize(img, img, cv::Size(iw, ih), 0, 0, cv::INTER_LINEAR);
139 |
140 |             std::vector<cv::Mat> bgr_channels(3);
141 | cv::split(img, bgr_channels);
142 |
143 |             // channels 2,1,0: BGR -> RGB
144 | imgs_hwc.push_back(bgr_channels[2].reshape(1, 1));
145 | imgs_hwc.push_back(bgr_channels[1].reshape(1, 1));
146 | imgs_hwc.push_back(bgr_channels[0].reshape(1, 1));
147 | }
148 |
149 |         // inference needs CHW layout!
150 | cv::Mat imgs_chw;
151 | cv::hconcat(imgs_hwc, imgs_chw);
152 |
153 | // d = 1,c = 9, w = 512, h = 288, cstep = 147456
154 | ncnn::Mat in(iw, ih, ic, (void*)imgs_chw.data);
155 | // print_shape(in);
156 |         std::chrono::duration<double, std::milli> elapsed = std::chrono::high_resolution_clock::now() - start;
157 | std::cout << "Preprocess time taken: " << elapsed.count() << " ms" << std::endl;
158 |
159 | start = std::chrono::high_resolution_clock::now();
160 | ncnn::Extractor ex = tracknet.create_extractor();
161 | ex.input("in0", in);
162 |
163 | ncnn::Mat out;
164 | ex.extract("out0", out);
165 | elapsed = std::chrono::high_resolution_clock::now() - start;
166 | std::cout << "Inference time taken: " << elapsed.count() << " ms" << std::endl;
167 |
168 | // post process
169 | // c = 3
170 | start = std::chrono::high_resolution_clock::now();
171 |         for (int i = 0; i < 3; i++)
--------------------------------------------------------------------------------
/tf2torch/diff.txt:
--------------------------------------------------------------------------------
30 | +        x = x.transpose(1, 3)   # NCHW ---> NWHC
31 | + x = self.bn(x)
32 | + x = x.transpose(1, 3) # NCHW <--- NWHC
33 | +
34 | + return x
35 |
36 |
37 | class TrackNet(nn.Module):
38 | @@ -28,42 +28,42 @@ class TrackNet(nn.Module):
39 |
40 | # VGG16
41 |          # self.conv2d_1 = Conv(3, 64)  # input: 3 grayscale frames
42 | -        self.conv2d_1 = Conv(9, 64) # input: 3 RGB frames
43 | -        self.conv2d_2 = Conv(64, 64)
44 | +        self.conv2d_1 = Conv(9, 64, 512) # input: 3 RGB frames
45 | + self.conv2d_2 = Conv(64, 64, 512)
46 | self.max_pooling_1 = nn.MaxPool2d((2,2), stride=(2,2))
47 |
48 | - self.conv2d_3 = Conv(64, 128)
49 | - self.conv2d_4 = Conv(128, 128)
50 | + self.conv2d_3 = Conv(64, 128, 256)
51 | + self.conv2d_4 = Conv(128, 128, 256)
52 | self.max_pooling_2 = nn.MaxPool2d((2,2), stride=(2,2))
53 |
54 | - self.conv2d_5 = Conv(128, 256)
55 | - self.conv2d_6 = Conv(256, 256)
56 | - self.conv2d_7 = Conv(256, 256)
57 | + self.conv2d_5 = Conv(128, 256, 128)
58 | + self.conv2d_6 = Conv(256, 256, 128)
59 | + self.conv2d_7 = Conv(256, 256, 128)
60 | self.max_pooling_3 = nn.MaxPool2d((2,2), stride=(2,2))
61 |
62 | - self.conv2d_8 = Conv(256, 512)
63 | - self.conv2d_9 = Conv(512, 512)
64 | - self.conv2d_10 = Conv(512, 512)
65 | + self.conv2d_8 = Conv(256, 512, 64)
66 | + self.conv2d_9 = Conv(512, 512, 64)
67 | + self.conv2d_10 = Conv(512, 512, 64)
68 |
69 | # Deconv / UNet
70 | self.up_sampling_1 = nn.UpsamplingNearest2d(scale_factor=2)
71 | # concatenate_1 with conv2d_7, axis = 1
72 |
73 | - self.conv2d_11 = Conv(768, 256)
74 | - self.conv2d_12 = Conv(256, 256)
75 | - self.conv2d_13 = Conv(256, 256)
76 | + self.conv2d_11 = Conv(768, 256, 128)
77 | + self.conv2d_12 = Conv(256, 256, 128)
78 | + self.conv2d_13 = Conv(256, 256, 128)
79 |
80 | self.up_sampling_2 = nn.UpsamplingNearest2d(scale_factor=2)
81 | # concatenate_2 with conv2d_4, axis = 1
82 |
83 | - self.conv2d_14 = Conv(384, 128)
84 | - self.conv2d_15 = Conv(128, 128)
85 | + self.conv2d_14 = Conv(384, 128, 256)
86 | + self.conv2d_15 = Conv(128, 128, 256)
87 |
88 | self.up_sampling_3 = nn.UpsamplingNearest2d(scale_factor=2)
89 | # concatenate_3 with conv2d_2, axis = 1
90 |
91 | - self.conv2d_16 = Conv(192, 64)
92 | - self.conv2d_17 = Conv(64, 64)
93 | + self.conv2d_16 = Conv(192, 64, 512)
94 | + self.conv2d_17 = Conv(64, 64, 512)
95 |          self.conv2d_18 = nn.Conv2d(64, 3, kernel_size=(1,1), padding='same') # output: 3 heatmaps
96 |          # self.conv2d_18 = Conv(64, 1, k=(1,1))  # output: 1 heatmap
97 |
98 |
--------------------------------------------------------------------------------
/tf2torch/onnx2pt.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import onnx
3 | from onnx2torch import convert
4 | from models.tracknet import TrackNet
5 | from torchsummary import summary
6 |
7 | # tensorflow weight
8 | #https://nol.cs.nctu.edu.tw:234/open-source/TrackNetv2/tree/master/3_in_3_out/model906_30
9 |
10 | # model906_30 ---> mode_save ---> track.onnx
11 | #
12 |
13 | # Path to ONNX model
14 | onnx_model_path = './track.onnx'
15 | # You can pass the path to the onnx model to convert it or...
16 |
17 |
18 | # onnx_model = onnx.load(onnx_model_path)
19 | # # https://stackoverflow.com/questions/53176229/strip-onnx-graph-from-its-constants-initializers
20 | # onnx_model.graph.ClearField('initializer')
21 | # torch_model = convert(onnx_model)
22 | torch_model = convert(onnx_model_path)
23 |
24 |
25 | onnx_dict = torch_model.state_dict()
26 |
27 | del onnx_dict['initializers.onnx_initializer_0']
28 | del onnx_dict['initializers.onnx_initializer_1']
29 | del onnx_dict['initializers.onnx_initializer_2']
30 | del onnx_dict['initializers.onnx_initializer_3']
31 | del onnx_dict['initializers.onnx_initializer_4']
32 | del onnx_dict['initializers.onnx_initializer_5']
33 |
34 |
35 |
36 |
37 | track_model = TrackNet()
38 | # print(summary(track_model,(9, 288, 512), device="cpu"))
39 | track_dict = track_model.state_dict()
40 |
41 |
42 |
43 | assert(len(onnx_dict)==len(track_dict))
44 |
45 | convert_dict = {}
46 | # print(track_dict.keys())
47 | # print(onnx_dict.keys())
48 |
49 | for k1, k2 in zip(onnx_dict.keys(), track_dict.keys()):
50 | # print(f'onnx: {k1} shape:{onnx_dict[k1].shape} track: {k2} shape:{track_dict[k2].shape}')
51 | convert_dict[k2] = onnx_dict[k1]
52 |
53 |
54 | torch.save(convert_dict, './track.pt')
55 |
56 |
57 |
58 |
59 | # ValueError: Unknown layer: BatchNormalization.
60 | # ValueError: Unknown loss function: custom_loss.
61 |
62 | # /home/chg/anaconda3/envs/track/lib/python3.10/site-packages/tf2onnx/tf_loader.py
63 |
64 | # line 645
65 | # def custom_loss(y_true, y_pred):
66 | # from tensorflow.python.keras import backend as K
67 | # loss = (-1)*(K.square(1 - y_pred) * y_true * K.log(K.clip(y_pred, K.epsilon(), 1)) + K.square(y_pred) * (1 - y_true) * K.log(K.clip(1 - y_pred, K.epsilon(), 1)))
68 | # return K.mean(loss)
69 |
70 | # def from_keras(model_path, input_names, output_names):
71 | # ...
72 | # custom_objects = {'custom_loss':custom_loss}
73 |
74 |
75 | # size mismatch for conv2d_1.bn.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([64]).
76 | # size mismatch for conv2d_1.bn.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([64]).
77 | # onnx: BatchNormalization_0.weight shape:torch.Size([512]) track: conv2d_1.bn.weight shape:torch.Size([64])
78 |
79 | # In TensorFlow, conv2d defaults to channels_last, so BatchNormalization by default normalizes the last axis (the C channel); after switching to channels_first,
80 | # BatchNormalization takes (N, C, H, W) input, with learned parameters over C
81 |
82 |
83 | # https://keras.io/api/layers/normalization_layers/batch_normalization/
84 | # In Keras/TensorFlow, BatchNormalization normalizes the last dimension of its input by default (axis=-1), which is usually the feature dimension. PyTorch's BatchNorm variants (BatchNorm1d, BatchNorm2d, ...) instead normalize over the second dimension, i.e. num_features.
--------------------------------------------------------------------------------
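A quick check that the converted weights load cleanly (a sketch; note that `tf2torch/diff.txt` must be applied first, since the converted BN parameters are sized along the TF model's BN axis, as the shape-mismatch log above shows):

```python
import torch
from models.tracknet import TrackNet  # with tf2torch/diff.txt applied

model = TrackNet()
model.load_state_dict(torch.load('./track.pt'))  # raises on any shape mismatch
model.eval()
```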
/tf2torch/track.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChgygLin/TrackNetV2-pytorch/185674ff3d97ef66f3f34ef111a705fd846dd402/tf2torch/track.pt
--------------------------------------------------------------------------------
/tools/Frame_Generator.py:
--------------------------------------------------------------------------------
1 | # Split a single video into image frames
2 | # python Frame_Generator.py Test/match1/videos/1_05_02.mp4 Test/match1/images/1_05_02
3 |
4 | import cv2
5 | import os
6 | import sys
7 | import shutil
8 |
9 |
10 | def extract_video(filePath, outputPath):
11 | if os.path.exists(outputPath):
12 | shutil.rmtree(outputPath)
13 |
14 | os.makedirs(outputPath)
15 |
16 | #Segment the video into frames
17 | cap = cv2.VideoCapture(filePath)
18 | success, count = True, 0
19 | success, image = cap.read()
20 |
21 | while success:
22 | imageFile = os.path.join(outputPath, '{}.jpg'.format(count))
23 | print(imageFile)
24 | cv2.imwrite(imageFile, image)
25 | count += 1
26 | success, image = cap.read()
27 |
28 |
29 | if __name__ == "__main__":
30 | try:
31 | filePath = sys.argv[1]
32 | outputPath = sys.argv[2]
33 |         if (not filePath) or (not outputPath):
34 |             raise ValueError
35 |     except:
36 |         print('usage: python3 Frame_Generator.py <video_path> <output_path>')
37 | exit(1)
38 |
39 |
40 | extract_video(filePath, outputPath)
41 |
42 |
43 |
--------------------------------------------------------------------------------
/tools/Frame_Generator_batch.py:
--------------------------------------------------------------------------------
1 | # Split every video in a directory into image frames; the sibling images directory is created automatically
2 | # python Frame_Generator_batch.py Test/match1/videos
3 |
4 | import os
5 | import sys
6 | from glob import glob
7 |
8 | from Frame_Generator import extract_video
9 |
10 | # Test/match1/videos/1_05_02.mp4 ----> Test/match1/images/1_05_02/*.jpg
11 | # Test/match1/videos/1_05_02.mp4 ----> Test/match1/images/1_05_02/*.jpg
12 | # ...
13 |
14 |
15 | def extract_videos(videosPath):
16 | for filePath in glob(os.path.join(videosPath, '*mp4')):
17 | tmp, _ = os.path.splitext(filePath) # Test/match1/videos/1_05_02.mp4 ---> Test/match1/videos/1_05_02
18 | imagePath = tmp.replace('videos', 'images') # Test/match1/videos/1_05_02 ---> Test/match1/images/1_05_02
19 |
20 | extract_video(filePath, imagePath)
21 |
22 |
23 | if __name__ == "__main__":
24 | try:
25 | videosPath = sys.argv[1]
26 |         if not videosPath:
27 |             raise ValueError
28 |     except:
29 |         print('usage: python3 Frame_Generator_batch.py <videos_path>')
30 | exit(1)
31 |
32 |
33 | extract_videos(videosPath)
34 |
--------------------------------------------------------------------------------
/tools/Frame_Generator_rally.py:
--------------------------------------------------------------------------------
1 | # Split all matches under a directory into image frames
2 | # python Frame_Generator_rally.py Test ---> Test/match1, Test/match2, Test/match3, ...
3 |
4 | import os
5 | import sys
6 |
7 | from Frame_Generator_batch import extract_videos
8 |
9 |
10 | def extract_rally(rallyPath):
11 | for dir in os.listdir(rallyPath):
12 | video_path = os.path.join(rallyPath, dir, "videos") # Test/match1/videos
13 |
14 | print(video_path)
15 | extract_videos(video_path)
16 |
17 |
18 |
19 | if __name__ == "__main__":
20 | try:
21 | rallyPath = sys.argv[1]
22 |         if not rallyPath:
23 |             raise ValueError
24 |     except:
25 |         print('usage: python3 Frame_Generator_rally.py <rally_path>')
26 | exit(1)
27 |
28 |
29 | extract_rally(rallyPath)
30 |
--------------------------------------------------------------------------------
/tools/check_labels.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import json
4 | import pandas as pd
5 | import numpy as np
6 |
7 | # Basic dataset validation
8 |
9 | base_path = "/home/chg/Documents/Badminton/merge_dataset/TrackNetV2/"
10 |
11 |
12 | # Amateur/match1
13 | def handle_rally(match_path):
14 | images_path = "{}/images".format(match_path)
15 | labels_path = "{}/labels".format(match_path)
16 |
17 |     # validate json labels
18 | for images_dir in os.listdir(images_path):
19 | for js_dir in glob.glob(os.path.join(images_path+"/"+images_dir, '*.json')):
20 | print(js_dir)
21 |
22 | with open(js_dir, 'r') as file:
23 | data = json.load(file)
24 |
25 | for shape in data["shapes"]:
26 | index = shape["label"]
27 |
28 | try:
29 | index = int(index)
30 | except ValueError:
31 | print("Error index: {}".format(index))
32 | exit()
33 |
34 | if index<=0 or index >=33:
35 | print("Error index: {}".format(index))
36 | exit()
37 |
38 |     # validate csv labels
39 | # for labels_dir in os.listdir(labels_path):
40 | for csv_dir in glob.glob(os.path.join(labels_path+"/", '*.csv')):
41 | images_base_path = csv_dir.replace("labels", "images").split(".")[0]
42 |
43 | print(csv_dir)
44 | df = pd.read_csv(csv_dir)
45 | frame_nums = df["frame_num"].values
46 |
47 | assert(np.all(np.ediff1d(frame_nums) == 1))
48 |
49 | for frame_num in frame_nums:
50 | if not os.path.exists("{}/{}.jpg".format(images_base_path, frame_num)):
51 | print("{}/{}.jpg".format(images_base_path, frame_num))
52 | exit()
53 |
54 | # Amateur
55 | def handle_rally_batch(batch_path):
56 | for rally_dir in os.listdir(batch_path):
57 |
58 | match_path = os.path.join(batch_path, rally_dir)
59 | handle_rally(match_path)
60 |
61 | # base_path: contains Amateur / Professional / Test
62 | def handle_base_path(base_path):
63 | for to_batch_dir in os.listdir(base_path):
64 | to_batch_path = os.path.join(base_path, to_batch_dir)
65 |
66 | handle_rally_batch(to_batch_path)
67 |
68 |
69 | handle_base_path(base_path)
--------------------------------------------------------------------------------
/tools/handle_Darklabel.py:
--------------------------------------------------------------------------------
1 |
2 | # match1/csv/xxx_ball.csv ---> match1/labels/xxx.csv
3 | # match1/video ---> match1/videos
4 |
5 | # Frame,Visibility,X,Y ---> frame_num,visible,x,y
6 | # 11,1,621,305 ---> 11,1,0.48515625,0.423611111
7 |
8 |
9 | import os
10 | import sys
11 | import glob
12 | import cv2
13 | import pandas as pd
14 |
15 | try:
16 | rallyPath = sys.argv[1]
17 |     if not rallyPath:
18 |         raise ValueError
19 | except:
20 |     print('usage: python3 handle_Darklabel.py <rally_path>')
21 | exit(1)
22 |
23 |
24 | for dir in os.listdir(rallyPath):
25 |
26 | # video_path = os.path.join(rallyPath, dir, "video") # Test/match1/video
27 | # csv_path = os.path.join(rallyPath, dir, "csv") # Test/match1/csv
28 |
29 |     # assert os.path.isdir(video_path), 'not a directory'
30 |     # assert os.path.isdir(csv_path), 'not a directory'
31 |
32 |     # # rename directories: video->videos, csv->labels
33 | # new_video_path = os.path.join(rallyPath, dir, "videos") # Test/match1/videos
34 | # new_csv_path = os.path.join(rallyPath, dir, "labels") # Test/match1/labels
35 |
36 | # os.rename(video_path, new_video_path)
37 | # os.rename(csv_path, new_csv_path)
38 |
39 |     new_video_path = os.path.join(rallyPath, dir, "videos") # Test/match1/videos
40 |     new_csv_path = os.path.join(rallyPath, dir, "labels") # Test/match1/labels
41 |
42 | # Test/match1/videos/1_05_02.mp4
43 | # Test/match1/labels/1_05_02.csv
44 | for file in os.listdir(new_csv_path):
45 | filename, _ = os.path.splitext(file)
46 |
47 | file_path = os.path.join(new_csv_path, file) # Test/match1/labels/1_07_03.csv
48 | video_path = os.path.join(new_video_path, '{}.mp4'.format(filename)) # Test/match1/videos/1_07_03.mp4
49 |
50 | print("handle video: {}".format(video_path))
51 |
52 | if os.path.exists(video_path):
53 |             # determine the video's width and height
54 | cap = cv2.VideoCapture(video_path)
55 | frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
56 | frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
57 | cap.release()
58 | else:
59 | frame_width, frame_height = 1280, 720
60 |
61 |         # convert pixel coordinates to ratios
62 | df = pd.read_csv(file_path)
63 | df.columns = ['frame_num', 'visible', 'x', 'y']
64 | # df = df.rename(columns={'Frame': 'frame_num', 'Visibility': 'visible', 'X': 'x', 'Y': 'y'})
65 |
66 | df['x'] = df['x'].astype(float)
67 | df['y'] = df['y'].astype(float)
68 |
69 | df.loc[:, 'x'] /= frame_width
70 | df.loc[:, 'y'] /= frame_height
71 |
72 | df.to_csv(file_path, index=False)
73 |
--------------------------------------------------------------------------------
/tools/handle_tracknet_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | # Preprocess the TrackNet dataset: https://hackmd.io/Nf8Rh1NrSrqNUzmO0sQKZw
3 |
4 | # match1/csv/xxx_ball.csv ---> match1/labels/xxx.csv
5 | # match1/video ---> match1/videos
6 |
7 | # Frame,Visibility,X,Y ---> frame_num,visible,x,y
8 | # 11,1,621,305 ---> 11,1,0.48515625,0.423611111
9 |
10 |
11 | # Input argument:
12 | # TrackNetV2/Professional
13 | # TrackNetV2/Amateur
14 | # TrackNetV2/Test
15 |
16 | import os
17 | import sys
18 | import glob
19 | import cv2
20 | import pandas as pd
21 |
22 | try:
23 | rallyPath = sys.argv[1]
24 | if not rallyPath :
25 | raise ''
26 | except:
27 | print('usage: python3 handle_dataset.py ')
28 | exit(1)
29 |
30 |
31 | for dir in os.listdir(rallyPath):
32 |
33 | video_path = os.path.join(rallyPath, dir, "video") # Test/match1/video
34 | csv_path = os.path.join(rallyPath, dir, "csv") # Test/match1/csv
35 |
36 |     assert os.path.isdir(video_path), 'not a directory'
37 |     assert os.path.isdir(csv_path), 'not a directory'
38 |
39 |     # rename csv files, dropping the '_ball' suffix
40 | for file in glob.glob(os.path.join(csv_path, "*.csv")):
41 | new_name = file.replace('_ball', '')
42 | os.rename(file, new_name)
43 |
44 |     # rename directories: video->videos, csv->labels
45 | new_video_path = os.path.join(rallyPath, dir, "videos") # Test/match1/videos
46 | new_csv_path = os.path.join(rallyPath, dir, "labels") # Test/match1/labels
47 |
48 | os.rename(video_path, new_video_path)
49 | os.rename(csv_path, new_csv_path)
50 |
51 |
52 | # Test/match1/videos/1_05_02.mp4
53 | # Test/match1/labels/1_05_02.csv
54 | for file in os.listdir(new_csv_path):
55 | filename, _ = os.path.splitext(file)
56 |
57 | file_path = os.path.join(new_csv_path, file) # Test/match1/labels/1_07_03.csv
58 | video_path = os.path.join(new_video_path, '{}.mp4'.format(filename)) # Test/match1/videos/1_07_03.mp4
59 |
60 |         # determine the video's width and height
61 | cap = cv2.VideoCapture(video_path)
62 | frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
63 | frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
64 | cap.release()
65 |
66 |         # convert pixel coordinates to ratios
67 | df = pd.read_csv(file_path)
68 | df = df.rename(columns={'Frame': 'frame_num', 'Visibility': 'visible', 'X': 'x', 'Y': 'y'})
69 | df['x'] = df['x'].astype(float)
70 | df['y'] = df['y'].astype(float)
71 |
72 | df.loc[:, 'x'] /= frame_width
73 | df.loc[:, 'y'] /= frame_height
74 |
75 | df.to_csv(file_path, index=False)
76 |
--------------------------------------------------------------------------------
/tools/label_tool.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import cv2
3 | from enum import Enum
4 | import pandas as pd
5 | import numpy as np
6 | import os
7 | import argparse
8 | import sys
9 | import time
10 |
11 | # python label_tool.py ./dataset/match1/videos/1_10_12.mp4 --csv_dir dataset/match/labels
12 | # if --csv_dir is omitted, the csv is read from the same directory as the mp4 file
13 |
14 |
15 | # state 0:hidden 1:visible
16 | state_name = ['HIDDEN', 'VISIBLE']
17 |
18 | keybindings = {
19 | 'next': [ ord('n') ],
20 | 'prev': [ ord('p')],
21 |
22 |     'piece_start': [ ord('s'), ],       # clip start frame
23 |     'piece_end': [ ord('e'), ],         # clip end frame
24 |
25 | 'first_frame': [ ord('z'), ],
26 | 'last_frame': [ ord('x'), ],
27 |
28 |     'forward_frames': [ ord('f'), ],    # jump forward 36 frames
29 |     'backward_frames': [ ord('b'), ],   # jump back 36 frames
30 |
31 | 'circle_grow': [ ord('='), ord('+') ],
32 | 'circle_shrink': [ ord('-'), ],
33 |
34 | 'quit': [ ord('q'), ],
35 | }
36 |
37 |
38 |
39 | class VideoPlayer():
40 | def __init__(self, opt) -> None:
41 | self.jump = 36
42 |
43 | self.cap = cv2.VideoCapture(opt.video_path)
44 | self.width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
45 | self.height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
46 | self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
47 | #self.fourcc = int(self.cap.get(cv.CAP_PROP_FOURCC))
48 | # self.fourcc = cv2.VideoWriter_fourcc('H', 'E', 'V', 'C')
49 | self.fourcc = cv2.VideoWriter_fourcc(*'mp4v')
50 | # self.fourcc = cv2.VideoWriter_fourcc(*'XVID')
51 | self.fps = self.cap.get(cv2.CAP_PROP_FPS)
52 | self.bitrate = self.cap.get(cv2.CAP_PROP_BITRATE)
53 |
54 | # Check video lens!
55 | # ret, frame = self.cap.read()
56 | # frame_cnt = 1
57 | # while ret:
58 | # ret, frame = self.cap.read()
59 | # if ret:
60 | # frame_cnt += 1
61 | # else:
62 | # print("Waringing: {} frame decode error!".format(frame_cnt))
63 |
64 | # if self.frames != frame_cnt:
65 | # print(" self.frames: {}, frame_cnt: {} ".format(self.frames, frame_cnt))
66 | # self.frames = frame_cnt
67 |
68 | self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
69 |
70 |
71 | self.video_path = Path(opt.video_path)
72 | self.circle_size = 5
73 | if opt.csv_dir is None:
74 | self.csv_path = self.video_path.with_suffix('.csv')
75 | else:
76 | self.csv_path = Path(opt.csv_dir) / Path(self.video_path.stem).with_suffix('.csv')
77 |
78 | self.window = cv2.namedWindow('Frame', cv2.WINDOW_NORMAL)
79 | cv2.resizeWindow('Frame', 1280, 720)
80 |
81 | _, self.frame = self.cap.read()
82 | self.frame_num = 0
83 |
84 | self.piece_start = 0
85 | self.piece_end = 0
86 |
87 | if os.path.exists(self.csv_path):
88 | self.info = pd.read_csv(self.csv_path)
89 |
90 | if len(self.info.index) != self.frames:
91 | print("pd len: {}, camera len: {}".format(len(self.info.index), self.frames))
92 | print("Number of frames in video and dictionary are not the same!")
93 | print("Fail to load!")
94 | exit(1)
95 |
96 |
97 | # self.info = {'frame_num':[], 'visible':[], 'x':[], 'y':[]}
98 |
99 | # for idx in range(self.frames):
100 | # self.info['frame_num'].append(idx)
101 | # self.info['visible'].append(0)
102 | # self.info['x'].append(0)
103 | # self.info['y'].append(0)
104 |
105 | # print("pandas dataframe len: {}".format(len(self.info)))
106 |
107 | else:
108 | self.info = {k: list(v.values()) for k, v in self.info.to_dict().items()}
109 | print("Load labeled {} successfully.".format(self.csv_path))
110 | else:
111 | print("Create new dictionary")
112 |
113 | self.info = {'frame_num':[], 'visible':[], 'x':[], 'y':[]}
114 |
115 | for idx in range(self.frames):
116 | self.info['frame_num'].append(idx)
117 | self.info['visible'].append(0)
118 | self.info['x'].append(0)
119 | self.info['y'].append(0)
120 |
121 | print("pandas dataframe len: {}".format(len(self.info)))
122 |
123 | cv2.setMouseCallback('Frame',self.markBall)
124 | self.display()
125 |
126 |
127 | def save_piece(self):
128 | self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.piece_start)
129 |
130 | out = cv2.VideoWriter('{}_{}.mp4'.format(self.piece_start+1, self.piece_end+1), self.fourcc, self.fps, (self.width, self.height))
131 |
132 | frame_cnt = self.piece_start
133 | while frame_cnt <= self.piece_end:
134 | ret, frame = self.cap.read()
135 | out.write(frame)
136 |
137 | frame_cnt += 1
138 |
139 | out.release()
140 |
141 | self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.frame_num)
142 | print("save piece succefully!")
143 |
144 | def markBall(self, event, x, y, flags, param):
145 | x /= self.width
146 | y /= self.height
147 | if event == cv2.EVENT_LBUTTONDOWN:
148 | self.info['frame_num'][self.frame_num] = self.frame_num
149 | self.info['x'][self.frame_num] = x
150 | self.info['y'][self.frame_num] = y
151 | self.info['visible'][self.frame_num] = 1
152 |
153 | elif event == cv2.EVENT_RBUTTONDBLCLK:
154 | self.info['frame_num'][self.frame_num] = self.frame_num
155 | self.info['x'][self.frame_num] = 0
156 | self.info['y'][self.frame_num] = 0
157 | self.info['visible'][self.frame_num] = 0
158 |
159 |
160 | def display(self):
161 | res_frame = self.frame.copy()
162 | res_frame = cv2.putText(res_frame, state_name[self.info['visible'][self.frame_num]], (100, 110), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 255, 0), 2, cv2.LINE_AA)
163 | res_frame = cv2.putText(res_frame, "Frame: {}, Total: {}".format(int(self.frame_num+1), int(self.frames)), (100, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 2, cv2.LINE_AA)
164 | res_frame = cv2.putText(res_frame, "Piece: {}-{}".format(int(self.piece_start+1), int(self.piece_end+1)), (100, 170), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 2, cv2.LINE_AA)
165 |
166 | if self.info['visible'][self.frame_num]:
167 | x = int(self.info['x'][self.frame_num] * self.width)
168 | y = int(self.info['y'][self.frame_num] * self.height)
169 | cv2.circle(res_frame, (x, y), self.circle_size, (0, 0, 255), -1)
170 |
171 | cv2.imshow('Frame', res_frame)
172 |
173 | #print("frame num: {}".format(self.frame_num))
174 | #print(type(self.frame))
175 | #assert(ret==True)
176 | # frame_num 0---->frames-1
177 | def main_loop(self):
178 | key = cv2.waitKeyEx(1)
179 | if key in keybindings['first_frame']:
180 | self.frame_num = 0
181 | self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.frame_num)
182 | ret, self.frame = self.cap.read()
183 |
184 | assert(ret==True)
185 |
186 | elif key in keybindings['last_frame']:
187 | self.frame_num = self.frames-1
188 |             self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.frame_num) # with cap.set, frame_num needs no +1
189 | ret, self.frame = self.cap.read()
190 |
191 | print(type(self.frame))
192 | assert(ret==True)
193 |
194 |
195 | elif key in keybindings['next']:
196 | if self.frame_num < self.frames-1:
197 | ret, self.frame = self.cap.read()
198 | self.frame_num += 1
199 |
200 | assert(ret==True)
201 |
202 | elif key in keybindings['prev']:
203 | time.sleep(0.01)
204 | if self.frame_num > 0:
205 | self.frame_num -= 1
206 | self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.frame_num)
207 | _, self.frame = self.cap.read()
208 |
209 |
210 | elif key in keybindings['forward_frames']:
211 | if self.frame_num < self.frames-1:
212 | for _ in range(self.jump):
213 |                     if self.frame_num == self.frames-2: # second-to-last frame; the final frame is fetched with read()
214 | break
215 |
216 |                     self.cap.grab()     # cap.grab skips a frame; frame_num advances by 1
217 | self.frame_num += 1
218 |
219 | ret, self.frame = self.cap.read()
220 | self.frame_num += 1
221 |
222 |
223 | elif key in keybindings['backward_frames']:
224 | if self.frame_num < self.jump:
225 | self.frame_num = 0
226 | else:
227 | self.frame_num -= self.jump
228 |
229 | self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.frame_num)
230 | ret, self.frame = self.cap.read()
231 |
232 | assert(ret==True)
233 |
234 |
235 | elif key in keybindings['circle_grow']:
236 | self.circle_size += 1
237 | elif key in keybindings['circle_shrink']:
238 | self.circle_size -= 1
239 |
240 |
241 | elif key in keybindings['piece_start']:
242 | self.piece_start = self.frame_num
243 |
244 | elif key in keybindings['piece_end']:
245 | self.piece_end = self.frame_num
246 | self.save_piece()
247 |
248 |
249 | elif key in keybindings['quit']:
250 | self.finish()
251 | return
252 |
253 |
254 | self.display()
255 |
256 |
257 | def finish(self):
258 | self.cap.release()
259 | cv2.destroyAllWindows()
260 | df = pd.DataFrame.from_dict(self.info).sort_values(by=['frame_num'], ignore_index=True)
261 | df.to_csv(self.csv_path, index=False)
262 |
263 |
264 | def __del__(self):
265 | self.finish()
266 |
267 |
268 | def parse_opt():
269 | parser = argparse.ArgumentParser()
270 | parser.add_argument('video_path', type=str, nargs='?', default=None, help='Path to the video file.')
271 | parser.add_argument('--csv_dir', type=str, default=None, help='Path to the directory where csv file should be saved. If not specified, csv file will be saved in the same directory as the video file.')
272 |     parser.add_argument('--remove_duplicate_frames', type=bool, default=False, help='Whether identical consecutive frames should be reduced to one frame.')
273 | opt = parser.parse_args()
274 | return opt
275 |
276 |
277 | def remove_duplicate_frames(video_path, output_path):
278 | # Open the video file
279 | vid = cv2.VideoCapture(video_path)
280 |
281 | # Set the frame width and height
282 | frame_width = int(vid.get(cv2.CAP_PROP_FRAME_WIDTH))
283 | frame_height = int(vid.get(cv2.CAP_PROP_FRAME_HEIGHT))
284 | fps = int(vid.get(cv2.CAP_PROP_FPS))
285 |
286 | # Create a VideoWriter object for the output video file
287 | out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))
288 |
289 | # Read and process the frames one by one
290 | previous_frame = None
291 | while True:
292 | # Read the next frame
293 | success, frame = vid.read()
294 |
295 | # If we reached the end of the video, break the loop
296 | if not success:
297 | break
298 |
299 | # If the current frame is not a duplicate, write it to the output video
300 | if previous_frame is None or cv2.PSNR(frame, previous_frame) < 40.:
301 | out.write(frame)
302 |
303 | # Update the previous frame
304 | previous_frame = frame
305 | print('finished removing duplicates')
306 |
307 |
308 |
309 | if __name__ == '__main__':
310 | opt = parse_opt()
311 |
312 | if opt.video_path is None:
313 | if getattr(sys, 'frozen', False):
314 | application_path = os.path.dirname(sys.executable)
315 | elif __file__:
316 | application_path = os.path.dirname(__file__)
317 | p = Path(application_path)
318 | video_path = next(p.glob('*.mp4'))
319 | toRemove = input('Should duplicated, consecutive frames be deleted? Insert "y" or "n": \n')
320 | if toRemove == 'y':
321 | bez_duplikatow_video_path = str(video_path.with_stem(video_path.stem + '_no_dups'))
322 | remove_duplicate_frames(str(video_path), bez_duplikatow_video_path)
323 | video_path = bez_duplikatow_video_path
324 | opt.video_path = str(video_path)
325 |
326 | # run as a CLI script
327 | elif opt.remove_duplicate_frames == True:
328 | remove_duplicate_frames(opt.video_path, opt.video_path)
329 |
330 | player = VideoPlayer(opt)
331 | while(player.cap.isOpened()):
332 | player.main_loop()
333 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 |
4 | import os
5 | import sys
6 | import numpy as np
7 | from pathlib import Path
8 | from tqdm import tqdm
9 | from argparse import ArgumentParser
10 | from tensorboardX import SummaryWriter
11 |
12 | from models.tracknet import TrackNet
13 | from utils.dataloaders import create_dataloader
14 | from utils.general import check_dataset, outcome, evaluation, tensorboard_log
15 |
16 |
17 | # from yolov5 detect.py
18 | FILE = Path(__file__).resolve()
19 | ABS_ROOT = FILE.parents[0] # YOLOv5 root directory
20 | if str(ABS_ROOT) not in sys.path:
21 | sys.path.append(str(ABS_ROOT)) # add ROOT to PATH
22 | ROOT = Path(os.path.relpath(ABS_ROOT, Path.cwd())) # relative
23 |
24 |
25 | def wbce_loss(y_true, y_pred):
26 | return -1*(
27 | ((1-y_pred)**2) * y_true * torch.log(torch.clamp(y_pred, min=1e-07, max=1)) +
28 | (y_pred**2) * (1-y_true) * torch.log(torch.clamp(1-y_pred, min=1e-07, max=1))
29 | ).sum()
30 |
31 |
32 | def validation_loop(device, model, val_loader, log_writer, epoch):
33 | model.eval()
34 |
35 | loss_sum = 0
36 | TP = TN = FP1 = FP2 = FN = 0
37 |
38 | with torch.inference_mode():
39 | pbar = tqdm(val_loader, ncols=180)
40 | for batch_index, (X, y) in enumerate(pbar):
41 | X, y = X.to(device), y.to(device)
42 | y_pred = model(X)
43 |
44 | loss_sum += wbce_loss(y, y_pred).item()
45 |
46 | y_ = y.detach().cpu().numpy()
47 | y_pred_ = y_pred.detach().cpu().numpy()
48 |
49 | y_pred_ = (y_pred_ > 0.5).astype('float32')
50 | (tp, tn, fp1, fp2, fn) = outcome(y_pred_, y_)
51 | TP += tp
52 | TN += tn
53 | FP1 += fp1
54 | FP2 += fp2
55 | FN += fn
56 |
57 | (accuracy, precision, recall) = evaluation(TP, TN, FP1, FP2, FN)
58 |
59 | pbar.set_description('Val loss: {:.6f} | TP: {}, TN: {}, FP1: {}, FP2: {}, FN: {} | Accuracy: {:.4f}, Precision: {:.4f}, Recall: {:.4f}'.format( \
60 | loss_sum / ((batch_index+1)*X.shape[0]), TP, TN, FP1, FP2, FN, accuracy, precision, recall))
61 |
62 | tensorboard_log(log_writer, "Val", loss_sum / ((batch_index+1)*X.shape[0]), TP, TN, FP1, FP2, FN, epoch)
63 |
64 | return loss_sum/len(val_loader)
65 |
66 |
67 | def training_loop(device, model, optimizer, lr_scheduler, train_loader, val_loader, start_epoch, epochs, save_dir):
68 | best_val_loss = float('inf')
69 |
70 | checkpoint_period = 3
71 | log_period = 100
72 |
73 | log_dir = '{}/logs'.format(save_dir)
74 | if not os.path.exists(log_dir):
75 | os.makedirs(log_dir)
76 |
77 | log_writer = SummaryWriter(log_dir)
78 |
79 | for epoch in range(start_epoch, epochs):
80 | print("\n==================================================================================================")
81 | tqdm.write("Epoch: {} / {}\n".format(epoch, epochs))
82 | running_loss = 0.0
83 | TP = TN = FP1 = FP2 = FN = 0
84 |
85 | model.train()
86 | pbar = tqdm(train_loader, ncols=180)
87 | for batch_index, (X, y) in enumerate(pbar):
88 | X, y = X.to(device), y.to(device)
89 | optimizer.zero_grad()
90 |
91 | y_pred = model(X)
92 |
93 | loss = wbce_loss(y, y_pred)
94 | loss.backward()
95 |
96 | optimizer.step()
97 |
98 | running_loss += loss.item()
99 |
100 | y_ = y.detach().cpu().numpy()
101 | y_pred_ = y_pred.detach().cpu().numpy()
102 |
103 | y_pred_ = (y_pred_ > 0.5).astype('float32')
104 | (tp, tn, fp1, fp2, fn) = outcome(y_pred_, y_)
105 | TP += tp
106 | TN += tn
107 | FP1 += fp1
108 | FP2 += fp2
109 | FN += fn
110 |
111 | (accuracy, precision, recall) = evaluation(TP, TN, FP1, FP2, FN)
112 |
113 | pbar.set_description('Train loss: {:.6f} | TP: {}, TN: {}, FP1: {}, FP2: {}, FN: {} | Accuracy: {:.4f}, Precision: {:.4f}, Recall: {:.4f}'.format( \
114 | running_loss / ((batch_index+1)*X.shape[0]), TP, TN, FP1, FP2, FN, accuracy, precision, recall))
115 |
116 | if batch_index % log_period == 0:
117 | with torch.inference_mode():
118 | images = [
119 | torch.unsqueeze(y[0,0,:,:], 0).repeat(3,1,1).cpu(),
120 | torch.unsqueeze(y_pred[0,0,:,:], 0).repeat(3,1,1).cpu(),
121 | ]
122 |
123 | images.append(X[0,(0,1,2),:,:].cpu())
124 | res = X[0, (0,1,2),:,:] * y[0,0,:,:]
125 |
126 | images.append(res.cpu())
127 | grid = torchvision.utils.make_grid(images, nrow=1)
128 |
129 | torchvision.utils.save_image(grid, '{}/epoch_{}_batch{}.png'.format(log_dir, epoch, batch_index))
130 |
131 |         best = False  # defined up front so the checkpoint logic below also works when val_loader is None
132 |         if val_loader is not None:
133 |             val_loss = validation_loop(device, model, val_loader, log_writer, epoch)
134 | if val_loss < best_val_loss:
135 | best_val_loss = val_loss
136 | best = True
137 |
138 | if epoch % checkpoint_period == checkpoint_period - 1:
139 | tqdm.write('\n--- Saving weights to: {}/last.pt ---'.format(save_dir))
140 | torch.save(model.state_dict(), '{}/last.pt'.format(save_dir))
141 |
142 | if best:
143 | tqdm.write('--- Saving weights to: {}/best.pt ---'.format(save_dir))
144 | torch.save(model.state_dict(), '{}/best.pt'.format(save_dir))
145 |
146 | tensorboard_log(log_writer, "Train", running_loss / ((batch_index+1)*X.shape[0]), TP, TN, FP1, FP2, FN, epoch)
147 |
148 | print('lr: {}'.format(lr_scheduler.get_last_lr()))
149 | lr_scheduler.step()
150 |
151 |         if epoch % 10 == 0:
152 | ckpt_dir = '{}/checkpoint'.format(save_dir)
153 | if not os.path.exists(ckpt_dir):
154 | os.makedirs(ckpt_dir)
155 |
156 |             cur_ckpt = '{}/{}/ckpt_{}.pt'.format(ABS_ROOT, ckpt_dir, epoch)  # absolute, so the symlink target resolves
157 |             latest_ckpt = '{}/{}/ckpt_latest.pt'.format(ABS_ROOT, ckpt_dir)
158 |
159 | checkpoint = {
160 | "net": model.state_dict(),
161 | 'optimizer': optimizer.state_dict(),
162 | "epoch": epoch+1,
163 | 'lr_scheduler': lr_scheduler.state_dict()
164 | }
165 |
166 | torch.save(checkpoint, cur_ckpt)
167 |             os.system('ln -sf {} {}'.format(cur_ckpt, latest_ckpt))  # POSIX-only; refreshes the 'latest' symlink
168 | print("save checkpoint {}".format(cur_ckpt))
169 |
170 |
171 | def parse_opt():
172 | parser = ArgumentParser()
173 |
174 |     parser.add_argument('--data', type=str, default=ROOT / 'data/test.yaml', help='Path to dataset.')
175 | parser.add_argument('--weights', type=str, default=ROOT / 'best.pt', help='Path to trained model weights.')
176 | parser.add_argument('--epochs', type=int, default=100, help='total training epochs')
177 | parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[288, 512], help='image size h,w')
178 | parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch')
179 | parser.add_argument('--project', default=ROOT / 'runs/train', help='save results to project/name')
180 | parser.add_argument('--resume', action='store_true', help='whether load checkpoint for resume')
181 |
182 | opt = parser.parse_args()
183 | return opt
184 |
185 |
186 | def main(opt):
187 | d_save_dir = str(opt.project)
188 | f_weights = str(opt.weights)
189 | epochs = opt.epochs
190 | batch_size = opt.batch_size
191 | f_data = str(opt.data)
192 | imgsz = opt.imgsz
193 |
194 | start_epoch = 0
195 |
196 | data_dict = check_dataset(f_data)
197 | train_path, val_path = data_dict['train'], data_dict['val']
198 |
199 | if not os.path.exists(d_save_dir):
200 | os.makedirs(d_save_dir)
201 |
202 |
203 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
204 | model = TrackNet().to(device)
205 |
206 | optimizer = torch.optim.Adadelta(model.parameters(), lr=0.99)
207 |
208 | # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=epochs)
209 | lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.9)
210 |
211 | if opt.resume:
212 | default_ckpt = "{}/checkpoint/ckpt_latest.pt".format(d_save_dir)
213 |         checkpoint = torch.load(default_ckpt, map_location=device)
214 |
215 | model.load_state_dict(checkpoint['net'])
216 | optimizer.load_state_dict(checkpoint['optimizer'])
217 | start_epoch = checkpoint['epoch']
218 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
219 |
220 | print("load checkpoint {}".format(default_ckpt))
221 | else:
222 | if os.path.exists(f_weights):
223 | print("load pretrain weights {}".format(f_weights))
224 |             model.load_state_dict(torch.load(f_weights, map_location=device))
225 | else:
226 | print("train from scratch")
227 |
228 | train_loader = create_dataloader(train_path, imgsz, batch_size=batch_size, augment=True, shuffle=True) # augment in training
229 | val_loader = create_dataloader(val_path, imgsz, batch_size=batch_size)
230 |
231 |
232 | training_loop(device, model, optimizer, lr_scheduler, train_loader, val_loader, start_epoch, epochs, d_save_dir)
233 |
234 |
235 |
236 | if __name__ == '__main__':
237 | opt = parse_opt()
238 | main(opt)
--------------------------------------------------------------------------------
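train.py optimizes the focal-style weighted BCE defined above: the (1-ŷ)² and ŷ² factors shrink the loss for confident, correct pixels, which matters because almost every heatmap pixel is background. A tiny numeric sanity check of that behavior (torch only; the toy tensors are illustrative, and the loss is redefined verbatim so the snippet runs standalone):

```python
import torch

def wbce_loss(y_true, y_pred):
    return -1 * (
        ((1 - y_pred) ** 2) * y_true * torch.log(torch.clamp(y_pred, min=1e-07, max=1)) +
        (y_pred ** 2) * (1 - y_true) * torch.log(torch.clamp(1 - y_pred, min=1e-07, max=1))
    ).sum()

y_true = torch.tensor([1.0, 0.0])
print(wbce_loss(y_true, torch.tensor([0.9, 0.1])))  # ~0.0021: confident and correct, heavily down-weighted
print(wbce_loss(y_true, torch.tensor([0.5, 0.5])))  # ~0.3466: uncertain predictions are penalized
```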
/utils/augmentations.py:
--------------------------------------------------------------------------------
1 | # YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
2 | """Image augmentation functions."""
3 |
4 | import math
5 | import random
6 |
7 | import cv2
8 | import numpy as np
9 | import torch
10 | import torchvision.transforms as T
11 | import torchvision.transforms.functional as TF
12 | import albumentations as A
13 |
14 |
15 | class Albumentations:
16 | # YOLOv5 Albumentations class (optional, only used if package is installed)
17 | def __init__(self, imgsz=[288, 512]):
18 | self.transform = None
19 | prefix = "albumentations: "
20 |
21 | T = [
22 | A.RandomResizedCrop(height=imgsz[0], width=imgsz[1], scale=(0.8, 1.0), ratio=(0.9, 1.11), p=0.0),
23 | A.Blur(p=0.1),
24 | A.MedianBlur(p=0.1),
25 | A.ToGray(p=0.1),
26 | A.CLAHE(p=0.1),
27 | A.RandomBrightnessContrast(p=0.0),
28 | A.RandomGamma(p=0.0),
29 | A.ImageCompression(quality_lower=75, p=0.0),
30 | ] # transforms
31 | self.transform = A.Compose(T, keypoint_params=A.KeypointParams(format="xy", remove_invisible=False))
32 |
33 | print(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p))
34 |
35 |
36 | def __call__(self, im, kps, p=1.0):
37 | if self.transform and random.random() < p:
38 | new = self.transform(image=im, keypoints=kps) # transformed
39 | im, kps = new["image"], np.array(new["keypoints"])
40 |
41 | return im, kps
42 |
43 |
44 | def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
45 | # HSV color-space augmentation
46 | if hgain or sgain or vgain:
47 | r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1 # random gains
48 | hue, sat, val = cv2.split(cv2.cvtColor(im, cv2.COLOR_BGR2HSV))
49 | dtype = im.dtype # uint8
50 |
51 | x = np.arange(0, 256, dtype=r.dtype)
52 | lut_hue = ((x * r[0]) % 180).astype(dtype)
53 | lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
54 | lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
55 |
56 | im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
57 | cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=im) # no return needed
58 |
59 |
60 | def random_perspective(
61 | im, kps, degrees=10, translate=0.1, scale=0.1, shear=10, perspective=0.0, border=(0, 0)
62 | ):
63 | # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(0.1, 0.1), scale=(0.9, 1.1), shear=(-10, 10))
64 |
65 | height = im.shape[0] + border[0] * 2 # shape(h,w,c)
66 | width = im.shape[1] + border[1] * 2
67 |
68 | # Center
69 | C = np.eye(3)
70 | C[0, 2] = -im.shape[1] / 2 # x translation (pixels)
71 | C[1, 2] = -im.shape[0] / 2 # y translation (pixels)
72 |
73 | # Perspective
74 | P = np.eye(3)
75 | P[2, 0] = random.uniform(-perspective, perspective) # x perspective (about y)
76 | P[2, 1] = random.uniform(-perspective, perspective) # y perspective (about x)
77 |
78 | # Rotation and Scale
79 | R = np.eye(3)
80 | a = random.uniform(-degrees, degrees)
81 | # a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations
82 | s = random.uniform(1 - scale, 1 + scale)
83 | # s = 2 ** random.uniform(-scale, scale)
84 | R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
85 |
86 | # Shear
87 | S = np.eye(3)
88 | S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # x shear (deg)
89 | S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # y shear (deg)
90 |
91 | # Translation
92 | T = np.eye(3)
93 | T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width # x translation (pixels)
94 | T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height # y translation (pixels)
95 |
96 | # Combined rotation matrix
97 | M = T @ S @ R @ P @ C # order of operations (right to left) is IMPORTANT
98 | if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any(): # image changed
99 | if perspective:
100 | im = cv2.warpPerspective(im, M, dsize=(width, height), borderValue=(114, 114, 114))
101 | else: # affine
102 | im = cv2.warpAffine(im, M[:2], dsize=(width, height), borderValue=(114, 114, 114))
103 |
104 | # Visualize
105 | # import matplotlib.pyplot as plt
106 | # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel()
107 | # ax[0].imshow(im[:, :, ::-1]) # base
108 | # ax[1].imshow(im2[:, :, ::-1]) # warped
109 | # plt.show()
110 |
111 | # Transform label coordinates
112 | n = len(kps)
113 |
114 | xy = np.ones((n, 3))
115 |     xy[:, :2] = kps  # pad to homogeneous coordinates
116 |
117 | xy = xy @ M.T # transform
118 | xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]) # perspective rescale or affine
119 |
120 | # clip
121 | xy[:, 0] = xy[:, 0].clip(0, width)
122 | xy[:, 1] = xy[:, 1].clip(0, height)
123 |
124 | return im, xy
125 |
126 |
127 | def random_flip(im, kps, p=0.5):
128 | # Flip left-right
129 | if random.random() < p:
130 | im = np.fliplr(im)
131 | kps[:, 0] = im.shape[1] - kps[:, 0] # width
132 |
133 | return im, kps
134 |
135 |
--------------------------------------------------------------------------------
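Taken together, these functions form a keypoint-aware pipeline: the geometric warp and flip move the shuttle coordinate along with the pixels, while the HSV jitter and Albumentations transforms only touch appearance. A minimal usage sketch under assumed inputs (frame.jpg and the keypoint value are hypothetical; requires opencv-python, numpy, and albumentations):

```python
import cv2
import numpy as np
from utils.augmentations import Albumentations, augment_hsv, random_flip, random_perspective

im = cv2.resize(cv2.imread("frame.jpg"), (512, 288))  # BGR frame at network input size
kps = np.array([[256.0, 144.0]])                      # one (x, y) keypoint in pixels

im, kps = random_perspective(im, kps)                 # affine warp; keypoint follows
alb = Albumentations(imgsz=[288, 512])                # prints the active transform list
im, kps = alb(im, kps)                                # blur / median blur / gray / CLAHE
augment_hsv(im, hgain=0.015, sgain=0.7, vgain=0.4)    # in-place color jitter
im, kps = random_flip(im, kps)                        # left-right flip with p=0.5
```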
/utils/dataloaders.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 |
4 | from torch.utils.data import Dataset, DataLoader
5 |
6 | import os
7 | import time
8 | import random
9 | import cv2
10 | import pandas as pd
11 | import numpy as np
12 |
13 | from utils.augmentations import random_perspective, Albumentations, augment_hsv, random_flip
14 |
15 |
16 | class ToTensor:
17 | # YOLOv5 ToTensor class for image preprocessing
18 | def __init__(self, half=False):
19 | super().__init__()
20 | self.half = half
21 |
22 | def __call__(self, im): # im = np.array HWC in BGR order
23 | im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1]) # HWC to CHW -> BGR to RGB -> contiguous
24 | im = torch.from_numpy(im) # to torch
25 | im = im.half() if self.half else im.float() # uint8 to fp16/32
26 | im /= 255.0 # 0-255 to 0.0-1.0
27 | return im
28 |
29 |
30 | # choosing num_workers: https://zhuanlan.zhihu.com/p/568076554
31 | # benchmark note: batch size 20, num_workers=1 -> 1.35, num_workers=8 -> 1.75
32 | def create_dataloader(path,
33 | imgsz=[288, 512],
34 | batch_size=1,
35 | sq=3,
36 | augment=False,
37 | workers=8,
38 | shuffle=False):
39 | print("create dataloader image size: {}".format(imgsz))
40 |
41 | dataset = LoadImagesAndLabels(path, imgsz, batch_size, sq, augment)
42 |
43 | batch_size = min(batch_size, len(dataset))
44 | nd = torch.cuda.device_count() # number of CUDA devices
45 | nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
46 | # print("num_workers: {}".format(nw))
47 |
48 | return DataLoader(dataset, batch_size=batch_size, num_workers=nw, shuffle=shuffle)
49 |
50 | # assert match/images exist, return path/images/xxx list
51 | def get_match_image_list(match_path):
52 | base_images = os.path.join(match_path, "images")
53 | assert os.path.exists(base_images), base_images+" is invalid"
54 |
55 | image_dir_list = [f.path for f in os.scandir(base_images) if f.is_dir()]
56 | return image_dir_list
57 |
58 | # return all match image list
59 | def get_rally_image_list(rally_path):
60 | image_dir_list = []
61 |
62 | for match_name in os.listdir(rally_path):
63 | image_dir_list.extend(get_match_image_list(os.path.join(rally_path, match_name)))
64 |
65 | return image_dir_list
66 |
67 | class LoadImagesAndLabels(Dataset):
68 | def __init__(self,
69 | path,
70 | imgsz=[288, 512],
71 | batch_size=1,
72 |                  sq=3,  # number of consecutive frames fed to the network
73 | augment=False,
74 | ):
75 | self.imgsz = imgsz
76 | self.path = path
77 | self.batch_size = batch_size
78 | self.sq = sq
79 | self.augment = augment
80 | self.albumentations = Albumentations(imgsz) if augment else None
81 |
82 |         self.image_dir_list = []   # every image directory in the dataset
83 |         self.label_path_list = []  # label csv paths, one per entry in image_dir_list
84 |         self.df_list = []          # parsed label dataframes
85 |         self.lens = []             # number of samples contributed by each directory
86 |
87 |
88 | # image : "./dataset1/match2/images/1_0_1"
89 | # image list : ["./dataset1/match2/images/1_0_1", "./dataset1/match2/images/1_1_1"]
90 |
91 | # match : "./dataset1/match2"
92 | # match list : ["./dataset1/match1", "./dataset1/match2"]
93 |
94 | # rally : "Professional"
95 | # rally list : ["Amateur", "Professional"]
96 |
97 | for p in path if isinstance(path, list) else [path]:
98 | if p[-1] == '/':
99 | p = p[:-1]
100 |
101 | if "images" in p: # image
102 | self.image_dir_list.append(p)
103 | elif "match" in p: # match
104 | self.image_dir_list.extend(get_match_image_list(p))
105 | else: # rally
106 | self.image_dir_list.extend(get_rally_image_list(p))
107 |
108 |
109 |         # check that the csv label length matches the number of files in the image directory
110 |
111 | print("\n")
112 | print(self.image_dir_list)
113 |
114 | # TODO::::::::
115 | # Check cache
116 |
117 |         # read the label csv files
118 | print("\n")
119 | for csv_path in self.image_dir_list:
120 | label_path = "{}.{}".format(csv_path.replace('images', 'labels'), "csv")
121 | self.label_path_list.append(label_path)
122 |
123 | df = pd.read_csv(label_path)
124 | df_len = len(df.index)
125 |
126 | self.lens.append(df_len - self.sq + 1)
127 | self.df_list.append(df)
128 | print("{} len: {}".format(csv_path, df_len))
129 | print("\n")
130 |
131 | def __len__(self):
132 | return sum(self.lens)
133 |
134 | def __getitem__(self, index):
135 | rel_index = index
136 |
137 |         # find which dataframe holds this sample (index is 0-based)
138 | for ix in range(len(self.lens)):
139 | if rel_index < self.lens[ix]:
140 | break
141 | else:
142 | rel_index -= self.lens[ix]
143 |
144 | #print("sample {} use label:{} relative index: {}".format(index, self.label_path_list[ix], rel_index))
145 |
146 | images, heatmaps = self._get_sample(self.image_dir_list[ix], self.df_list[ix], rel_index)
147 |
148 | return images, heatmaps
149 |
150 |
151 | def _get_sample(self, image_dir, label_data, image_rel_index):
152 | images = []
153 | heatmaps = []
154 |
155 | w = self.imgsz[1]
156 | h = self.imgsz[0]
157 |
158 | rd_state = random.getstate()
159 | rd_seed = time.time()
160 | for i in range(self.sq):
161 | #if (int(label_data['frame_num'][image_rel_index+i]) != int(image_rel_index+i)):
162 | # print(image_dir)
163 | # print(label_data['frame_num'])
164 | # print("{} ---> {}".format(label_data['frame_num'][image_rel_index+i], image_rel_index+i))
165 | # assert(int(label_data['frame_num'][image_rel_index+i]) == int(image_rel_index+i))
166 |
167 |
168 | image_path = image_dir + "/" + str(label_data['frame_num'][image_rel_index+i]) + ".jpg"
169 | img = cv2.imread(image_path) # BGR
170 |
171 | interp = cv2.INTER_LINEAR if (self.augment) else cv2.INTER_AREA
172 | img = cv2.resize(img, (w, h), interpolation=interp)
173 |
174 |
175 | visible = label_data['visible'][image_rel_index+i]
176 | x = label_data['x'][image_rel_index+i]
177 | y = label_data['y'][image_rel_index+i]
178 |
179 | kps_int = np.array([int(w*x), int(h*y), int(visible)]).reshape(1, -1)
180 | kps_xy = kps_int[:, :2]
181 | assert(len(kps_xy) == 1)
182 |
183 | if self.augment:
184 |                 # reseed with the same wall-clock value so every frame in the sequence gets identical augmentations
185 | random.seed(rd_seed)
186 |
187 | img, kps_xy = random_perspective(img, kps_xy)
188 |
189 | img, kps_xy = self.albumentations(img, kps_xy)
190 |
191 | augment_hsv(img, hgain=0.015, sgain=0.7, vgain=0.4)
192 |
193 | img, kps_xy = random_flip(img, kps_xy)
194 |
195 | # kps_int will return
196 | kps_int[:, :2] = kps_xy
197 | kps_int = kps_int.astype(int)
198 |
199 |
200 |                 # Generate the ground-truth heatmap from kps_int = (x, y, visible),
201 |                 # which reflects any augmentation applied above. A non-visible
202 |                 # shuttle is encoded as a negative center, which _gen_heatmap
203 |                 # turns into an all-zero map.
204 | 
205 |                 # x, y, visible
206 |                 if kps_int[0][2] == 0:
207 |                     heatmap = self._gen_heatmap(w, h, -1, -1)
208 |                 else:
209 |                     x = kps_int[0][0]
210 |                     y = kps_int[0][1]
211 | 
212 |                     heatmap = self._gen_heatmap(w, h, x, y)
213 |
214 |
215 | img = ToTensor()(img)
216 |
217 | images.append(img)
218 | heatmaps.append(heatmap)
219 |
220 | random.setstate(rd_state)
221 |
222 |         images = torch.concatenate(images)  # stack the sq frames along the channel dim -> (sq*3, h, w)
223 | heatmaps = torch.tensor(np.array(heatmaps), requires_grad=False, dtype=torch.float32)
224 |
225 | return images, heatmaps
226 |
227 |
228 | def _gen_heatmap(self, w, h, cx, cy, r=2.5, mag=1):
229 | if cx < 0 or cy < 0:
230 | return np.zeros((h, w))
231 | x, y = np.meshgrid(np.linspace(1, w, w), np.linspace(1, h, h))
232 | heatmap = ((y - (cy + 1))**2) + ((x - (cx + 1))**2)
233 | heatmap[heatmap <= r**2] = 1
234 | heatmap[heatmap > r**2] = 0
235 |
236 | return heatmap*mag
237 |
238 |
239 | def _make_gaussian(size=(1920, 1080), center=(0.5, 0.5), fwhm=(5, 5)):
240 |     """ Make a 2-D Gaussian image.
241 | 
242 |     size: (width, height) of the output
243 |     center: normalized (cx, cy) center in [0, 1]
244 |     fwhm: full width at half maximum per axis
245 |
246 | source: https://stackoverflow.com/questions/7687679/how-to-generate-2d-gaussian-with-python
247 | """
248 |
249 | x = np.arange(0, size[0], 1, float)
250 | y = np.arange(0, size[1], 1, float)[:,np.newaxis]
251 |
252 | x0 = size[0]*center[0]
253 | y0 = size[1]*center[1]
254 |
255 | return np.exp(-4*np.log(2) * ((x-x0)**2/fwhm[0]**2 + (y-y0)**2/fwhm[1]**2))
256 |
257 |
258 | if __name__ == "__main__":
259 | if not os.path.exists('./runs/loader_test'):
260 | os.makedirs('./runs/loader_test')
261 |
262 | batch_size = 1
263 | test_loader = create_dataloader("./example_dataset/match/images/1_10_12", batch_size=batch_size)
264 |
265 | for index, (_images,_heatmaps) in enumerate(test_loader):
266 | jj = 0
267 | # for jj in range(batch_size):
268 | hms = []
269 | for ii in range(3):
270 |             hm = _heatmaps[jj,ii,:,:].repeat(3,1,1)  # save_image expects floats in [0, 1], so no *255 is needed
271 | hms.append(hm)
272 | # torchvision.utils.save_image(hms, './runs/loader_test/batch{}_{}_heatmap.png'.format(index, jj))
273 |
274 | ims = []
275 | for ii in range(3):
276 |             im = _images[jj,(0+ii*3,1+ii*3,2+ii*3),:,:]  # likewise already in [0, 1]; save_image rescales to 0-255
277 | ims.append(im)
278 | hms.append(im)
279 | # torchvision.utils.save_image(ims, './runs/loader_test/batch{}_{}_image.png'.format(index, jj))
280 |
281 | hms = torchvision.utils.make_grid(hms, nrow=3)
282 | torchvision.utils.save_image(hms, './runs/loader_test/batch{}_{}.png'.format(index, jj))
283 |
284 | if index >= 10:
285 | break
286 |
287 |
288 |
--------------------------------------------------------------------------------
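Note that _gen_heatmap produces a binary disk rather than a smooth Gaussian (the unused _make_gaussian above would give the Gaussian alternative): every pixel within r=2.5 px of the ball center is 1, everything else 0, and a hidden ball yields an all-zero map. A behaviorally equivalent standalone sketch, rewritten on a 0-based grid (numpy only; the sizes are illustrative):

```python
import numpy as np

def gen_heatmap(w, h, cx, cy, r=2.5, mag=1):
    if cx < 0 or cy < 0:                              # "not visible" is encoded as (-1, -1)
        return np.zeros((h, w))
    x, y = np.meshgrid(np.arange(w), np.arange(h))    # 0-based pixel grid
    inside = (x - cx) ** 2 + (y - cy) ** 2 <= r ** 2  # squared distance vs radius
    return inside.astype(float) * mag                 # 1 inside the disk, 0 outside

print(gen_heatmap(16, 8, cx=5, cy=3).astype(int))     # a small disk of ones around (5, 3)
```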
/utils/general.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | from pathlib import Path
3 | import cv2
4 | import numpy as np
5 |
6 |
7 |
8 | FILE = Path(__file__).resolve()
9 | ROOT = FILE.parents[1]  # repository root (pattern borrowed from YOLOv5)
10 |
11 |
12 | def yaml_load(file):
13 | # Single-line safe yaml loading
14 | with open(file, errors='ignore') as f:
15 | return yaml.safe_load(f)
16 |
17 | def check_dataset(data):
18 | if isinstance(data, (str, Path)):
19 | data = yaml_load(data) # dictionary
20 |
21 |     path = Path(data.get('path', '.'))  # optional 'path', defaults to '.'
22 | if not path.is_absolute():
23 | path = (ROOT / path).resolve()
24 | data['path'] = path # download scripts
25 |
26 | for k in 'train', 'val':
27 | if data.get(k): # prepend path
28 | if isinstance(data[k], str):
29 | x = (path / data[k]).resolve()
30 | if not x.exists() and data[k].startswith('../'):
31 | x = (path / data[k][3:]).resolve()
32 | data[k] = str(x)
33 | else:
34 | data[k] = [str((path / x).resolve()) for x in data[k]]
35 |
36 | return data
37 |
38 |
39 | # img: 0/1 binary image, numpy array.
40 | def get_shuttle_position(img):
41 | if np.amax(img) <= 0:
42 | # (visible, cx, cy)
43 | return (0, 0, 0)
44 |
45 | else:
46 | (cnts, _) = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
47 | assert (len(cnts) != 0)
48 |
49 | rects = [cv2.boundingRect(ctr) for ctr in cnts]
50 | max_area_idx = 0
51 | max_area = rects[max_area_idx][2] * rects[max_area_idx][3]
52 |
53 | for ii in range(len(rects)):
54 | area = rects[ii][2] * rects[ii][3]
55 | if area > max_area:
56 | max_area_idx = ii
57 | max_area = area
58 |
59 | target = rects[max_area_idx]
60 | (cx, cy) = (int(target[0] + target[2] / 2), int(target[1] + target[3] / 2))
61 |
62 | # (visible, cx, cy)
63 | return (1, cx, cy)
64 |
65 |
66 | def outcome(y_pred, y_true, tol=3): # [batch, 3, h, w]
67 | n = y_pred.shape[0]
68 | i = 0
69 | TP = TN = FP1 = FP2 = FN = 0
70 | while i < n:
71 | for j in range(3):
72 | if np.amax(y_pred[i][j]) == 0 and np.amax(y_true[i][j]) == 0:
73 | TN += 1
74 | elif np.amax(y_pred[i][j]) > 0 and np.amax(y_true[i][j]) == 0:
75 | FP2 += 1
76 | elif np.amax(y_pred[i][j]) == 0 and np.amax(y_true[i][j]) > 0:
77 | FN += 1
78 | elif np.amax(y_pred[i][j]) > 0 and np.amax(y_true[i][j]) > 0:
79 | h_pred = y_pred[i][j] * 255
80 | h_true = y_true[i][j] * 255
81 | h_pred = h_pred.astype('uint8')
82 | h_true = h_true.astype('uint8')
83 |
84 | (_, cx_pred, cy_pred) = get_shuttle_position(h_pred)
85 | (_, cx_true, cy_true) = get_shuttle_position(h_true)
86 |
87 | dist = np.sqrt(pow(cx_pred-cx_true, 2)+pow(cy_pred-cy_true, 2))
88 |
89 | if dist > tol:
90 | FP1 += 1
91 | else:
92 | TP += 1
93 | i += 1
94 | return (TP, TN, FP1, FP2, FN)
95 |
96 |
97 | def evaluation(TP, TN, FP1, FP2, FN):
98 |     try:
99 |         accuracy = (TP + TN) / (TP + TN + FP1 + FP2 + FN)
100 |     except ZeroDivisionError:
101 |         accuracy = 0
102 |     try:
103 |         precision = TP / (TP + FP1 + FP2)
104 |     except ZeroDivisionError:
105 |         precision = 0
106 |     try:
107 |         recall = TP / (TP + FN)
108 |     except ZeroDivisionError:
109 |         recall = 0
110 |
111 | return (accuracy, precision, recall)
112 |
113 | def tensorboard_log(log_writer, type, avg_loss, TP, TN, FP1, FP2, FN, epoch):
114 |     log_writer.add_scalar('{}/loss'.format(type), avg_loss, epoch)
115 |     log_writer.add_scalar('{}/TP'.format(type), TP, epoch)
116 |     log_writer.add_scalar('{}/TN'.format(type), TN, epoch)
117 |     log_writer.add_scalar('{}/FP1'.format(type), FP1, epoch)
118 |     log_writer.add_scalar('{}/FP2'.format(type), FP2, epoch)
119 |     log_writer.add_scalar('{}/FN'.format(type), FN, epoch)
120 | 
121 |     (accuracy, precision, recall) = evaluation(TP, TN, FP1, FP2, FN)
122 | 
123 |     # derived metrics
124 |     log_writer.add_scalar('{}/Accuracy'.format(type), accuracy, epoch)
125 |     log_writer.add_scalar('{}/Precision'.format(type), precision, epoch)
126 |     log_writer.add_scalar('{}/Recall'.format(type), recall, epoch)
--------------------------------------------------------------------------------
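outcome() classifies each predicted heatmap against the ground truth: TN when both are empty, FP2 for a detection on an empty frame, FN for a miss, and then TP vs FP1 depending on whether the predicted center lands within tol=3 px of the true one. A quick synthetic check of that logic (numpy assumed; the blob positions are made up):

```python
import numpy as np
from utils.general import evaluation, get_shuttle_position, outcome

h, w = 288, 512
heat = np.zeros((h, w), dtype='uint8')
heat[100:103, 200:203] = 1                      # 3x3 blob centred at (201, 101)
print(get_shuttle_position(heat))               # -> (1, 201, 101)

y_true = np.zeros((1, 3, h, w), dtype='float32')
y_pred = np.zeros((1, 3, h, w), dtype='float32')
y_true[0, 0, 100:103, 200:203] = 1
y_pred[0, 0, 101:104, 201:204] = 1              # ~1.4 px off -> inside tol=3 -> TP
print(outcome(y_pred, y_true))                  # -> (1, 2, 0, 0, 0): one TP, two empty channels as TN
print(evaluation(1, 2, 0, 0, 0))                # -> (1.0, 1.0, 1.0)
```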
/val.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 |
4 | import os
5 | import sys
6 | import cv2
7 | import numpy as np
8 | from pathlib import Path
9 | from tqdm import tqdm
10 | from argparse import ArgumentParser
11 |
12 | from models.tracknet import TrackNet
13 | from utils.dataloaders import create_dataloader
14 | from utils.general import check_dataset, outcome, evaluation
15 |
16 |
17 | # from yolov5 detect.py
18 | FILE = Path(__file__).resolve()
19 | ABS_ROOT = FILE.parents[0]  # repository root (pattern borrowed from YOLOv5)
20 | if str(ABS_ROOT) not in sys.path:
21 | sys.path.append(str(ABS_ROOT)) # add ROOT to PATH
22 | ROOT = Path(os.path.relpath(ABS_ROOT, Path.cwd())) # relative
23 |
24 |
25 |
26 |
27 | def wbce_loss(y_true, y_pred):
28 | return -1*(
29 | ((1-y_pred)**2) * y_true * torch.log(torch.clamp(y_pred, min=1e-07, max=1)) +
30 | (y_pred**2) * (1-y_true) * torch.log(torch.clamp(1-y_pred, min=1e-07, max=1))
31 | ).sum()
32 |
33 |
34 | def validation_loop(device, model, val_loader, save_dir):
35 | model.eval()
36 |
37 | loss_sum = 0
38 | TP = TN = FP1 = FP2 = FN = 0
39 |
40 | with torch.inference_mode():
41 | pbar = tqdm(val_loader, ncols=180)
42 | for batch_index, (X, y) in enumerate(pbar):
43 | X, y = X.to(device), y.to(device)
44 | y_pred = model(X)
45 |
46 | loss_sum += wbce_loss(y, y_pred).item()
47 |
48 | y_ = y.detach().cpu().numpy()
49 | y_pred_ = y_pred.detach().cpu().numpy()
50 |
51 | y_pred_ = (y_pred_ > 0.5).astype('float32')
52 | (tp, tn, fp1, fp2, fn) = outcome(y_pred_, y_)
53 | TP += tp
54 | TN += tn
55 | FP1 += fp1
56 | FP2 += fp2
57 | FN += fn
58 |
59 | (accuracy, precision, recall) = evaluation(TP, TN, FP1, FP2, FN)
60 |
61 | pbar.set_description('Val loss: {:.6f} | TP: {}, TN: {}, FP1: {}, FP2: {}, FN: {} | Accuracy: {:.4f}, Precision: {:.4f}, Recall: {:.4f}'.format( \
62 | loss_sum / ((batch_index+1)*X.shape[0]), TP, TN, FP1, FP2, FN, accuracy, precision, recall))
63 |
64 |         F1 = 2 * (precision*recall) / (precision+recall) if (precision+recall) > 0 else 0
65 | print("F1-score: {}".format(F1))
66 |
67 | return loss_sum/len(val_loader)
68 |
69 |
70 |
71 | def parse_opt():
72 | parser = ArgumentParser()
73 |
74 |     parser.add_argument('--data', type=str, default=ROOT / 'data/test.yaml', help='Path to dataset.')
75 | parser.add_argument('--weights', type=str, default=ROOT / 'best.pt', help='Path to trained model weights.')
76 | parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[288, 512], help='image size h,w')
77 | parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch')
78 | parser.add_argument('--project', default=ROOT / 'runs/val', help='save results to project/name')
79 |
80 | opt = parser.parse_args()
81 | return opt
82 |
83 |
84 | def main(opt):
85 | d_save_dir = str(opt.project)
86 | f_weights = str(opt.weights)
87 | batch_size = opt.batch_size
88 | f_data = str(opt.data)
89 | imgsz = opt.imgsz
90 |
91 |
92 | data_dict = check_dataset(f_data)
93 | train_path, val_path = data_dict['train'], data_dict['val']
94 |
95 | if not os.path.exists(d_save_dir):
96 | os.makedirs(d_save_dir)
97 |
98 |
99 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
100 | model = TrackNet().to(device)
101 |
102 | assert os.path.exists(f_weights), f_weights+" is invalid"
103 | print("load pretrain weights {}".format(f_weights))
104 |     model.load_state_dict(torch.load(f_weights, map_location=device))
105 |
106 |
107 | val_loader = create_dataloader(val_path, imgsz, batch_size=batch_size)
108 |
109 | validation_loop(device, model, val_loader, d_save_dir)
110 |
111 |
112 | if __name__ == '__main__':
113 | opt = parse_opt()
114 | main(opt)
--------------------------------------------------------------------------------
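For reference, the F1 printed by val.py combines the precision and recall definitions from utils/general.py, where FP1 is a detection more than tol pixels from the truth and FP2 is a detection on a frame with no shuttle. A quick sketch of the arithmetic with made-up counts (not real results):

```python
TP, TN, FP1, FP2, FN = 900, 50, 20, 10, 20        # illustrative counts only
precision = TP / (TP + FP1 + FP2)                 # 900 / 930 ~= 0.9677
recall = TP / (TP + FN)                           # 900 / 920 ~= 0.9783
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
print('precision={:.4f} recall={:.4f} F1={:.4f}'.format(precision, recall, f1))
```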