├── .gitignore ├── README.md ├── data └── test.yaml ├── deploy ├── app.py └── test_app.py ├── detect.py ├── example_dataset └── match │ ├── images │ └── 1_10_12 │ │ ├── 0.jpg │ │ ├── 1.jpg │ │ ├── 2.jpg │ │ ├── 3.jpg │ │ ├── 4.jpg │ │ └── 5.jpg │ ├── labels │ └── 1_10_12.csv │ └── videos │ └── 1_10_12.mp4 ├── models └── tracknet.py ├── ncnn_inference ├── Makefile ├── convert_pnnx.py └── tracknet.cpp ├── requirements.txt ├── runs └── .gitkeep ├── tf2torch ├── diff.txt ├── onnx2pt.py └── track.pt ├── tools ├── Frame_Generator.py ├── Frame_Generator_batch.py ├── Frame_Generator_rally.py ├── check_labels.py ├── handle_Darklabel.py ├── handle_tracknet_dataset.py └── label_tool.py ├── train.py ├── utils ├── augmentations.py ├── dataloaders.py └── general.py └── val.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | __pycache__ 3 | runs -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TrackNetV2-pytorch 2 | 3 | Paper: **TrackNetV2: Efficient Shuttlecock Tracking Network** 4 | 5 | Original Project(tensorflow): https://nol.cs.nctu.edu.tw:234/open-source/TrackNetv2 6 | 7 | > 官方上传的标注工具、数据集均已失效。del> 8 | > 9 | > The author has now reuploaded the dataset。 10 | 11 | Paper reading:[TrackNetV2论文记录与pytorch复现](https://zhuanlan.zhihu.com/p/624900770) 12 | 13 | 14 | 15 | ## Inference with pytorch weights converted from tensorflow weights: 16 | 17 | ```shell 18 | git apply tf2torch/diff.txt 19 | python detect.py --source xxx.mp4 --weights ./tf2torch/track.pt --view-img # TrackNetv2/3_in_3_out/model906_30 20 | ``` 21 | 22 | 23 | 24 | ## Inference: 25 | 26 | ``` 27 | python detect.py --source xxx.mp4 --weights xxx.pt --view-img 28 | ``` 29 | 30 | 31 | 32 | ## Training: 33 | 34 | ``` 35 | # training from scratch 36 | python train.py --data data/match.yaml 37 | 38 | # training from pretrain weight 39 
| python train.py --weights xxx.pt --data data/match.yaml 40 | 41 | # resume training 42 | python train.py --data data/match.yaml --resume 43 | ``` 44 | 45 | 46 | 47 | ## Evaluation: 48 | 49 | ```shell 50 | python val.py --weights xxx.pt --data data/match.yaml 51 | ``` 52 | 53 | 54 | 55 | ## Deployment: 56 | 57 | ```shell 58 | # Server 59 | python deploy/app.py --weights xxx.pt 60 | 61 | # Client 62 | python deploy/test_app.py 63 | ``` 64 | 65 | 66 | 67 | 68 | 69 | ## Dataset Preparation: 70 | 71 | ``` 72 | # TrackNetV2 dataset 73 | # /home/chg/Badminton/TrackNetV2 74 | # - Amateur 75 | # - Professional 76 | # - Test 77 | 78 | python tools/handle_tracknet_dataset.py /home/chg/Badminton/TrackNetV2/Amateur 79 | python tools/handle_tracknet_dataset.py /home/chg/Badminton/TrackNetV2/Professional 80 | python tools/handle_tracknet_dataset.py /home/chg/Badminton/TrackNetV2/Test 81 | 82 | python tools/Frame_Generator_rally.py /home/chg/Badminton/TrackNetV2/Amateur 83 | python tools/Frame_Generator_rally.py /home/chg/Badminton/TrackNetV2/Professional 84 | python tools/Frame_Generator_rally.py /home/chg/Badminton/TrackNetV2/Test 85 | 86 | 87 | # TrackNetV2 dataset config : data/match.yaml 88 | path: /home/chg/Documents/Badminton/TrackNetV2 89 | train: 90 | - Amateur 91 | - Professional 92 | val: 93 | - Test 94 | 95 | # also you can use follow config for testing 96 | train: 97 | - Test/match1/images/1_05_02 98 | val: 99 | - Test/match2/images/1_03_03 100 | 101 | # or 102 | train: 103 | - Test/match1 104 | val: 105 | - Test/match2 106 | 107 | ``` 108 | 109 | 110 | 111 | ## Reference: 112 | 113 | https://github.com/mareksubocz/TrackNet 114 | 115 | https://nol.cs.nctu.edu.tw:234/open-source/TrackNetv2 116 | 117 | https://github.com/ultralytics/yolov5 -------------------------------------------------------------------------------- /data/test.yaml: -------------------------------------------------------------------------------- 1 | # YOLOv5 🚀 by Ultralytics, GPL-3.0 license 2 | # 
COCO 2017 dataset http://cocodataset.org 3 | # Example usage: python train.py --data coco.yaml 4 | # parent 5 | # ├── yolov5 6 | # └── datasets 7 | # └── battle ← downloads here 8 | 9 | 10 | # Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..] 11 | path: ./example_dataset/match # dataset root dir 12 | train: 13 | - images/1_10_12 14 | 15 | val: 16 | - images/1_10_12 17 | 18 | # Classes 19 | # nc: 3 # number of classes 20 | # names: ['mcar-0', 'mpeo-1', 'peo-2'] 21 | -------------------------------------------------------------------------------- /deploy/app.py: -------------------------------------------------------------------------------- 1 | import io 2 | import json 3 | 4 | import sys 5 | from pathlib import Path 6 | FILE = Path(__file__).resolve() 7 | ROOT = FILE.parents[1] # 1 -> root directory 8 | if str(ROOT) not in sys.path: 9 | sys.path.append(str(ROOT)) # add ROOT to PATH 10 | 11 | import cv2 12 | import tempfile 13 | import numpy as np 14 | from argparse import ArgumentParser 15 | 16 | import torch 17 | from torchvision import models 18 | import torchvision.transforms as transforms 19 | from PIL import Image 20 | from flask import Flask, jsonify, request 21 | 22 | from models.tracknet import TrackNet 23 | from utils.general import get_shuttle_position 24 | 25 | # reference: https://pytorch.org/tutorials/intermediate/flask_rest_api_tutorial.html 26 | 27 | 28 | def prediction(video, model, device, imgsz): 29 | vid_cap = cv2.VideoCapture(video) 30 | fps = vid_cap.get(cv2.CAP_PROP_FPS) 31 | w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 32 | h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 33 | 34 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 35 | out = cv2.VideoWriter("./predict.mp4", fourcc, fps, (w, h)) 36 | 37 | count = 0 38 | video_end = False 39 | while vid_cap.isOpened(): 40 | imgs = [] 41 | for _ in range(3): 42 | ret, img = vid_cap.read() 43 | 44 | if not ret: 45 | video_end = True 46 
| break 47 | imgs.append(img) 48 | 49 | if video_end: 50 | break 51 | 52 | imgs_torch = [] 53 | for img in imgs: 54 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 55 | 56 | img_torch = transforms.ToTensor()(img).to(device) # already [0, 1] 57 | img_torch = transforms.functional.resize(img_torch, imgsz, antialias=True) 58 | 59 | imgs_torch.append(img_torch) 60 | 61 | imgs_torch = torch.cat(imgs_torch, dim=0).unsqueeze(0) 62 | 63 | preds = model(imgs_torch) 64 | preds = preds[0].detach().cpu().numpy() 65 | 66 | y_preds = preds > 0.5 67 | y_preds = y_preds.astype('float32') 68 | y_preds = y_preds*255 69 | y_preds = y_preds.astype('uint8') 70 | 71 | for i in range(3): 72 | (visible, cx_pred, cy_pred) = get_shuttle_position(y_preds[i]) 73 | (cx, cy) = (int(cx_pred*w/imgsz[1]), int(cy_pred*h/imgsz[0])) 74 | if visible: 75 | cv2.circle(imgs[i], (cx, cy), 8, (0,0,255), -1) 76 | 77 | 78 | out.write(imgs[i]) 79 | print("{} ---- visible: {} cx: {} cy: {}".format(count, visible, cx, cy)) 80 | 81 | count += 1 82 | 83 | out.release() 84 | vid_cap.release() 85 | 86 | 87 | 88 | 89 | 90 | def parse_opt(): 91 | parser = ArgumentParser() 92 | 93 | parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[288, 512], help='image size h,w') 94 | parser.add_argument('--weights', type=str, default=ROOT / 'best.pt', help='Path to trained model weights.') 95 | 96 | opt = parser.parse_args() 97 | 98 | return opt 99 | 100 | 101 | def main(opt): 102 | f_weights = str(opt.weights) 103 | imgsz = opt.imgsz 104 | 105 | device = "cuda" 106 | 107 | model = TrackNet().to(device) 108 | model.load_state_dict(torch.load(f_weights)) 109 | model.eval() 110 | print("initialize TrackNet, load weights: {}".format(f_weights)) 111 | 112 | app = Flask(__name__) 113 | 114 | @app.route('/predict', methods=['POST']) 115 | def predict(): 116 | if request.method == 'POST': 117 | file = request.files['file'] 118 | 119 | # file.save('video.mp4') 120 | file_bytes = 
np.asarray(bytearray(file.read()), dtype=np.uint8) 121 | 122 | with tempfile.NamedTemporaryFile() as temp: 123 | temp.write(file_bytes) 124 | 125 | prediction(temp.name, model, device, imgsz) 126 | 127 | return 'Video processed successfully' 128 | 129 | app.run() 130 | 131 | 132 | if __name__ == '__main__': 133 | opt = parse_opt() 134 | main(opt) -------------------------------------------------------------------------------- /deploy/test_app.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | data = {'file': open('/home/chg/Videos/ld/2012.mp4', 'rb')} 4 | 5 | resp = requests.post("http://localhost:5000/predict", files=data) 6 | print(resp.text) -------------------------------------------------------------------------------- /detect.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | 4 | import os 5 | import sys 6 | import cv2 7 | import numpy as np 8 | from pathlib import Path 9 | from argparse import ArgumentParser 10 | 11 | from models.tracknet import TrackNet 12 | from utils.general import get_shuttle_position 13 | 14 | # from yolov5 detect.py 15 | FILE = Path(__file__).resolve() 16 | ROOT = FILE.parents[0] # YOLOv5 root directory 17 | if str(ROOT) not in sys.path: 18 | sys.path.append(str(ROOT)) # add ROOT to PATH 19 | ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative 20 | 21 | 22 | def parse_opt(): 23 | parser = ArgumentParser() 24 | 25 | parser.add_argument('--source', type=str, default=ROOT / 'example_dataset/match/videos/1_10_12.mp4', help='Path to video.') 26 | parser.add_argument('--save-txt', action='store_true', help='save results to *.csv') 27 | parser.add_argument('--view-img', action='store_true', help='show results') 28 | parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[288, 512], help='image size h,w') 29 | parser.add_argument('--weights', type=str, default=ROOT / 
'best.pt', help='Path to trained model weights.') 30 | parser.add_argument('--project', default=ROOT / 'runs/detect', help='save results to project/name') 31 | 32 | opt = parser.parse_args() 33 | 34 | return opt 35 | 36 | 37 | def main(opt): 38 | # imgsz = [288, 512] 39 | # imgsz = [360, 640] 40 | 41 | source_name = os.path.splitext(os.path.basename(opt.source))[0] 42 | b_save_txt = opt.save_txt 43 | b_view_img = opt.view_img 44 | d_save_dir = str(opt.project) 45 | f_weights = str(opt.weights) 46 | f_source = str(opt.source) 47 | imgsz = opt.imgsz 48 | 49 | # video_name ---> video_name_pred 50 | source_name = '{}_predict'.format(source_name) 51 | 52 | # runs/detect 53 | if not os.path.exists(d_save_dir): 54 | os.makedirs(d_save_dir) 55 | 56 | # runs/detect/video_name 57 | img_save_path = '{}/{}'.format(d_save_dir, source_name) 58 | if not os.path.exists(img_save_path): 59 | os.makedirs(img_save_path) 60 | 61 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 62 | 63 | model = TrackNet().to(device) 64 | model.load_state_dict(torch.load(f_weights)) 65 | model.eval() 66 | 67 | # import ncnn 68 | # net = ncnn.Net() 69 | # net.load_param("./pt_30_optimize.ncnn.param") 70 | # net.load_model("./pt_30_optimize.ncnn.bin") 71 | 72 | vid_cap = cv2.VideoCapture(f_source) 73 | video_end = False 74 | 75 | video_len = vid_cap.get(cv2.CAP_PROP_FRAME_COUNT) 76 | fps = vid_cap.get(cv2.CAP_PROP_FPS) 77 | w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 78 | h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 79 | 80 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 81 | out = cv2.VideoWriter('{}/{}.mp4'.format(d_save_dir, source_name), fourcc, fps, (w, h)) 82 | 83 | if b_save_txt: 84 | f_save_txt = open('{}/{}.csv'.format(d_save_dir, source_name), 'w') 85 | f_save_txt.write('frame_num,visible,x,y\n') 86 | 87 | if b_view_img: 88 | cv2.namedWindow(source_name, cv2.WINDOW_NORMAL) 89 | cv2.resizeWindow(source_name, (w, h)) 90 | 91 | count = 0 92 | while vid_cap.isOpened(): 93 | imgs = [] 94 | 
for _ in range(3): 95 | ret, img = vid_cap.read() 96 | 97 | if not ret: 98 | video_end = True 99 | break 100 | imgs.append(img) 101 | 102 | if video_end: 103 | break 104 | 105 | imgs_torch = [] 106 | for img in imgs: 107 | # https://www.geeksforgeeks.org/converting-an-image-to-a-torch-tensor-in-python/ 108 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 109 | 110 | img_torch = torchvision.transforms.ToTensor()(img).to(device) # already [0, 1] 111 | img_torch = torchvision.transforms.functional.resize(img_torch, imgsz, antialias=True) 112 | 113 | imgs_torch.append(img_torch) 114 | 115 | imgs_torch = torch.cat(imgs_torch, dim=0).unsqueeze(0) 116 | 117 | preds = model(imgs_torch) 118 | preds = preds[0].detach().cpu().numpy() 119 | 120 | # ncnn 121 | # ex = net.create_extractor() 122 | # ex.input("in0", ncnn.Mat(imgs_torch.squeeze(0).numpy()).clone()) 123 | # _, out0 = ex.extract("out0") 124 | # preds = np.array(out0) 125 | 126 | y_preds = preds > 0.5 127 | y_preds = y_preds.astype('float32') 128 | y_preds = y_preds*255 129 | y_preds = y_preds.astype('uint8') 130 | 131 | for i in range(3): 132 | (visible, cx_pred, cy_pred) = get_shuttle_position(y_preds[i]) 133 | (cx, cy) = (int(cx_pred*w/imgsz[1]), int(cy_pred*h/imgsz[0])) 134 | if visible: 135 | cv2.circle(imgs[i], (cx, cy), 8, (0,0,255), -1) 136 | 137 | if b_save_txt: 138 | f_save_txt.write('{},{},{},{}\n'.format(count, visible, cx, cy)) 139 | 140 | if b_view_img: 141 | cv2.imwrite('{}/{}.png'.format(img_save_path, count), imgs[i]) 142 | cv2.imshow(source_name, imgs[i]) 143 | cv2.waitKey(1) 144 | 145 | out.write(imgs[i]) 146 | print("{} ---- visible: {} cx: {} cy: {}".format(count, visible, cx, cy)) 147 | 148 | count += 1 149 | 150 | if cv2.waitKey(1) & 0xFF == ord('q'): 151 | break 152 | 153 | if b_save_txt: 154 | # 每次识别3张,最后可能有1-2张没有识别,补0 155 | while count < video_len: 156 | f_save_txt.write('{},0,0,0\n'.format(count)) 157 | count += 1 158 | 159 | f_save_txt.close() 160 | 161 | out.release() 162 | vid_cap.release() 
163 | cv2.destroyAllWindows() 164 | 165 | 166 | if __name__ == '__main__': 167 | opt = parse_opt() 168 | main(opt) 169 | -------------------------------------------------------------------------------- /example_dataset/match/images/1_10_12/0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChgygLin/TrackNetV2-pytorch/185674ff3d97ef66f3f34ef111a705fd846dd402/example_dataset/match/images/1_10_12/0.jpg -------------------------------------------------------------------------------- /example_dataset/match/images/1_10_12/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChgygLin/TrackNetV2-pytorch/185674ff3d97ef66f3f34ef111a705fd846dd402/example_dataset/match/images/1_10_12/1.jpg -------------------------------------------------------------------------------- /example_dataset/match/images/1_10_12/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChgygLin/TrackNetV2-pytorch/185674ff3d97ef66f3f34ef111a705fd846dd402/example_dataset/match/images/1_10_12/2.jpg -------------------------------------------------------------------------------- /example_dataset/match/images/1_10_12/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChgygLin/TrackNetV2-pytorch/185674ff3d97ef66f3f34ef111a705fd846dd402/example_dataset/match/images/1_10_12/3.jpg -------------------------------------------------------------------------------- /example_dataset/match/images/1_10_12/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChgygLin/TrackNetV2-pytorch/185674ff3d97ef66f3f34ef111a705fd846dd402/example_dataset/match/images/1_10_12/4.jpg -------------------------------------------------------------------------------- 
/example_dataset/match/images/1_10_12/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChgygLin/TrackNetV2-pytorch/185674ff3d97ef66f3f34ef111a705fd846dd402/example_dataset/match/images/1_10_12/5.jpg -------------------------------------------------------------------------------- /example_dataset/match/labels/1_10_12.csv: -------------------------------------------------------------------------------- 1 | frame_num,visible,x,y 2 | 0,1,0.6442708333333333,0.424074074074074 3 | 1,1,0.64375,0.4185185185185185 4 | 2,1,0.6427083333333333,0.4175925925925925 5 | 3,1,0.6421875,0.412962962962963 6 | 4,1,0.6416666666666667,0.412037037037037 7 | 5,1,0.6401041666666667,0.3944444444444444 8 | -------------------------------------------------------------------------------- /example_dataset/match/videos/1_10_12.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChgygLin/TrackNetV2-pytorch/185674ff3d97ef66f3f34ef111a705fd846dd402/example_dataset/match/videos/1_10_12.mp4 -------------------------------------------------------------------------------- /models/tracknet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from torchsummary import summary 5 | 6 | 7 | class Conv(nn.Module): 8 | def __init__(self, ic, oc, k=(3,3), p="same", act=True): 9 | super().__init__() 10 | self.conv = nn.Conv2d(ic, oc, kernel_size=k, padding=p) 11 | self.bn = nn.BatchNorm2d(oc) 12 | self.act = nn.ReLU() if act else nn.Identity() 13 | 14 | # self.convs = nn.Sequential( 15 | # nn.Conv2d(ic, oc, kernel_size=k, padding=p), 16 | # nn.ReLU(), 17 | # nn.BatchNorm2d(oc), 18 | # ) 19 | 20 | def forward(self, x): 21 | return self.bn(self.act(self.conv(x))) # 和relu-bn-conv不一样? 
22 | #return self.convs(x) 23 | 24 | 25 | class TrackNet(nn.Module): 26 | def __init__(self): 27 | super().__init__() 28 | 29 | # VGG16 30 | # self.conv2d_1 = Conv(3, 64) 输入3张灰度图 31 | self.conv2d_1 = Conv(9, 64) # 输入3张RGB图 32 | self.conv2d_2 = Conv(64, 64) 33 | self.max_pooling_1 = nn.MaxPool2d((2,2), stride=(2,2)) 34 | 35 | self.conv2d_3 = Conv(64, 128) 36 | self.conv2d_4 = Conv(128, 128) 37 | self.max_pooling_2 = nn.MaxPool2d((2,2), stride=(2,2)) 38 | 39 | self.conv2d_5 = Conv(128, 256) 40 | self.conv2d_6 = Conv(256, 256) 41 | self.conv2d_7 = Conv(256, 256) 42 | self.max_pooling_3 = nn.MaxPool2d((2,2), stride=(2,2)) 43 | 44 | self.conv2d_8 = Conv(256, 512) 45 | self.conv2d_9 = Conv(512, 512) 46 | self.conv2d_10 = Conv(512, 512) 47 | 48 | # Deconv / UNet 49 | self.up_sampling_1 = nn.UpsamplingNearest2d(scale_factor=2) 50 | # concatenate_1 with conv2d_7, axis = 1 51 | 52 | self.conv2d_11 = Conv(768, 256) 53 | self.conv2d_12 = Conv(256, 256) 54 | self.conv2d_13 = Conv(256, 256) 55 | 56 | self.up_sampling_2 = nn.UpsamplingNearest2d(scale_factor=2) 57 | # concatenate_2 with conv2d_4, axis = 1 58 | 59 | self.conv2d_14 = Conv(384, 128) 60 | self.conv2d_15 = Conv(128, 128) 61 | 62 | self.up_sampling_3 = nn.UpsamplingNearest2d(scale_factor=2) 63 | # concatenate_3 with conv2d_2, axis = 1 64 | 65 | self.conv2d_16 = Conv(192, 64) 66 | self.conv2d_17 = Conv(64, 64) 67 | self.conv2d_18 = nn.Conv2d(64, 3, kernel_size=(1,1), padding='same') # 输出3张图 68 | # self.conv2d_18 = Conv(64, 1, k=(1,1)) 输出1张图 69 | 70 | 71 | def forward(self, x): 72 | # VGG16 73 | x = self.conv2d_1(x) 74 | x1 = self.conv2d_2(x) 75 | x = self.max_pooling_1(x1) 76 | 77 | x = self.conv2d_3(x) 78 | x2 = self.conv2d_4(x) 79 | x = self.max_pooling_2(x2) 80 | 81 | x = self.conv2d_5(x) 82 | x = self.conv2d_6(x) 83 | x3 = self.conv2d_7(x) 84 | x = self.max_pooling_3(x3) 85 | 86 | x = self.conv2d_8(x) 87 | x = self.conv2d_9(x) 88 | x = self.conv2d_10(x) 89 | 90 | # Deconv / UNet 91 | x = self.up_sampling_1(x) 92 | x 
= torch.concat([x, x3], dim=1) 93 | 94 | x = self.conv2d_11(x) 95 | x = self.conv2d_12(x) 96 | x = self.conv2d_13(x) 97 | 98 | x = self.up_sampling_2(x) 99 | x = torch.concat([x, x2], dim=1) 100 | 101 | x = self.conv2d_14(x) 102 | x = self.conv2d_15(x) 103 | 104 | x = self.up_sampling_3(x) 105 | x = torch.concat([x, x1], dim=1) 106 | 107 | x = self.conv2d_16(x) 108 | x = self.conv2d_17(x) 109 | x = self.conv2d_18(x) 110 | 111 | x = torch.sigmoid(x) 112 | 113 | return x 114 | 115 | 116 | # torch.load 加载权重 117 | # model.load_state_dict 将权重加载到模型中 118 | 119 | # model.state_dict() 模型的权重 120 | # torch.save 保存模型的权重 121 | 122 | 123 | if __name__ == '__main__': 124 | model = TrackNet() 125 | print(summary(model,(9, 288, 512), device="cpu")) 126 | #print(model) 127 | -------------------------------------------------------------------------------- /ncnn_inference/Makefile: -------------------------------------------------------------------------------- 1 | all: 2 | g++ -o tracknet tracknet.cpp -lncnn -I../../ncnn-20240820-ubuntu-2204-shared/include/ncnn -L../../ncnn-20240820-ubuntu-2204-shared/lib/ `pkg-config --cflags --libs opencv4` 3 | 4 | #export LD_LIBRARY_PATH=../../ncnn-20240820-ubuntu-2204-shared/lib/ 5 | #./tracknet ../smash_test.mp4 -------------------------------------------------------------------------------- /ncnn_inference/convert_pnnx.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pnnx 3 | 4 | from pathlib import Path 5 | import sys 6 | 7 | FILE = Path(__file__).resolve() 8 | ROOT = FILE.parents[1] 9 | 10 | if str(ROOT) not in sys.path: 11 | sys.path.append(str(ROOT)) # add ROOT to PATH 12 | 13 | from models.tracknet import TrackNet 14 | 15 | 16 | model = TrackNet().to("cpu") 17 | model.load_state_dict(torch.load("./last.pt")) 18 | model = model.eval() 19 | 20 | x = torch.rand(1, 9, 288, 512) 21 | 22 | opt_model = pnnx.export(model, "./last_opt.pt", x, fp16 = False) 23 | 
-------------------------------------------------------------------------------- /ncnn_inference/tracknet.cpp: -------------------------------------------------------------------------------- 1 | // Tencent is pleased to support the open source community by making ncnn available. 2 | // 3 | // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved. 4 | // 5 | // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except 6 | // in compliance with the License. You may obtain a copy of the License at 7 | // 8 | // https://opensource.org/licenses/BSD-3-Clause 9 | // 10 | // Unless required by applicable law or agreed to in writing, software distributed 11 | // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 12 | // CONDITIONS OF ANY KIND, either express or implied. See the License for the 13 | // specific language governing permissions and limitations under the License. 14 | 15 | #include "net.h" 16 | 17 | #include 18 | #include 19 | #include 20 | // #include 21 | 22 | #include 23 | #include 24 | #include 25 | #include 26 | 27 | cv::Mat hwc2chw(const cv::Mat& src_mat9) 28 | { 29 | std::vector bgr_channels(9); 30 | cv::split(src_mat9, bgr_channels); 31 | for (size_t i = 0; i < bgr_channels.size(); i++) 32 | { 33 | bgr_channels[i] = bgr_channels[i].reshape(1, 1); // reshape为1通道,1行,n列 34 | } 35 | cv::Mat dst_mat; 36 | cv::hconcat(bgr_channels, dst_mat); 37 | return dst_mat; 38 | } 39 | 40 | 41 | void print_shape(const ncnn::Mat &in) 42 | { 43 | std::cout << "d: " << in.d << " c: " << in.c << " w: " << in.w << " h: " << in.h << " cstep: " << in.cstep << std::endl; 44 | } 45 | 46 | std::tuple get_shuttle_position(const cv::Mat binary_pred) 47 | { 48 | if (cv::countNonZero(binary_pred) <= 0) 49 | { 50 | // (visible, cx, cy) 51 | return std::make_tuple(0, 0, 0); 52 | } 53 | else 54 | { 55 | std::vector> cnts; 56 | cv::findContours(binary_pred, cnts, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE); 57 | 
assert(cnts.size()!= 0); 58 | 59 | std::vector rects; 60 | for (const auto& ctr : cnts) { 61 | rects.push_back(cv::boundingRect(ctr)); 62 | } 63 | 64 | int max_area_idx = 0; 65 | int max_area = rects[max_area_idx].width * rects[max_area_idx].height; 66 | 67 | for (size_t ii = 0; ii < rects.size(); ++ii) { 68 | int area = rects[ii].width * rects[ii].height; 69 | if (area > max_area) { 70 | max_area_idx = ii; 71 | max_area = area; 72 | } 73 | } 74 | 75 | cv::Rect target = rects[max_area_idx]; 76 | int cx = target.x + target.width / 2; 77 | int cy = target.y + target.height / 2; 78 | 79 | // (visible, cx, cy) 80 | return std::make_tuple(1, cx, cy); 81 | } 82 | } 83 | 84 | 85 | static int detect_tracknet(const char* video_path) 86 | { 87 | ncnn::Net tracknet; 88 | 89 | // GPU 90 | tracknet.opt.use_vulkan_compute = false; 91 | 92 | if(tracknet.load_param("./last_opt.ncnn.param")) 93 | exit(-1); 94 | if(tracknet.load_model("./last_opt.ncnn.bin")) 95 | exit(-1); 96 | 97 | 98 | cv::VideoCapture vid_cap(video_path); 99 | bool video_end = false; 100 | 101 | int video_len = vid_cap.get(cv::CAP_PROP_FRAME_COUNT); 102 | double fps = vid_cap.get(cv::CAP_PROP_FPS); 103 | int w = vid_cap.get(cv::CAP_PROP_FRAME_WIDTH); 104 | int h = vid_cap.get(cv::CAP_PROP_FRAME_HEIGHT); 105 | 106 | int iw = 512; 107 | int ih = 288; 108 | int ic = 9; 109 | 110 | int count = 0; 111 | while (vid_cap.isOpened()) 112 | { 113 | std::vector imgs; 114 | for (int i = 0; i < 3; ++i) 115 | { 116 | cv::Mat img; 117 | bool ret = vid_cap.read(img); 118 | if (!ret) 119 | { 120 | video_end = true; 121 | break; 122 | } 123 | 124 | imgs.push_back(img); 125 | } 126 | 127 | if (video_end) 128 | break; 129 | 130 | 131 | auto start = std::chrono::high_resolution_clock::now(); 132 | std::vector imgs_hwc; 133 | for (int i=0; i<3; i++) 134 | { 135 | cv::Mat img; 136 | imgs[i].convertTo(img, CV_32F, 1.0 / 255.0); 137 | 138 | cv::resize(img, img, cv::Size(iw, ih), 0, 0, cv::INTER_LINEAR); 139 | 140 | std::vector 
bgr_channels(3); 141 | cv::split(img, bgr_channels); 142 | 143 | // 210: bgr - > rgb 144 | imgs_hwc.push_back(bgr_channels[2].reshape(1, 1)); 145 | imgs_hwc.push_back(bgr_channels[1].reshape(1, 1)); 146 | imgs_hwc.push_back(bgr_channels[0].reshape(1, 1)); 147 | } 148 | 149 | // inference need chw ! 150 | cv::Mat imgs_chw; 151 | cv::hconcat(imgs_hwc, imgs_chw); 152 | 153 | // d = 1,c = 9, w = 512, h = 288, cstep = 147456 154 | ncnn::Mat in(iw, ih, ic, (void*)imgs_chw.data); 155 | // print_shape(in); 156 | std::chrono::duration elapsed = std::chrono::high_resolution_clock::now() - start; 157 | std::cout << "Preprocess time taken: " << elapsed.count() << " ms" << std::endl; 158 | 159 | start = std::chrono::high_resolution_clock::now(); 160 | ncnn::Extractor ex = tracknet.create_extractor(); 161 | ex.input("in0", in); 162 | 163 | ncnn::Mat out; 164 | ex.extract("out0", out); 165 | elapsed = std::chrono::high_resolution_clock::now() - start; 166 | std::cout << "Inference time taken: " << elapsed.count() << " ms" << std::endl; 167 | 168 | // post process 169 | // c = 3 170 | start = std::chrono::high_resolution_clock::now(); 171 | for (int i=0; i NWHC 31 | + x = self.bn(x) 32 | + x = x.transpose(1, 3) # NCHW <--- NWHC 33 | + 34 | + return x 35 | 36 | 37 | class TrackNet(nn.Module): 38 | @@ -28,42 +28,42 @@ class TrackNet(nn.Module): 39 | 40 | # VGG16 41 | # self.conv2d_1 = Conv(3, 64) 输入3张灰度图 42 | - self.conv2d_1 = Conv(9, 64) # 输入3张RGB图 43 | - self.conv2d_2 = Conv(64, 64) 44 | + self.conv2d_1 = Conv(9, 64, 512) # 输入3张RGB图 45 | + self.conv2d_2 = Conv(64, 64, 512) 46 | self.max_pooling_1 = nn.MaxPool2d((2,2), stride=(2,2)) 47 | 48 | - self.conv2d_3 = Conv(64, 128) 49 | - self.conv2d_4 = Conv(128, 128) 50 | + self.conv2d_3 = Conv(64, 128, 256) 51 | + self.conv2d_4 = Conv(128, 128, 256) 52 | self.max_pooling_2 = nn.MaxPool2d((2,2), stride=(2,2)) 53 | 54 | - self.conv2d_5 = Conv(128, 256) 55 | - self.conv2d_6 = Conv(256, 256) 56 | - self.conv2d_7 = Conv(256, 256) 57 | + 
self.conv2d_5 = Conv(128, 256, 128) 58 | + self.conv2d_6 = Conv(256, 256, 128) 59 | + self.conv2d_7 = Conv(256, 256, 128) 60 | self.max_pooling_3 = nn.MaxPool2d((2,2), stride=(2,2)) 61 | 62 | - self.conv2d_8 = Conv(256, 512) 63 | - self.conv2d_9 = Conv(512, 512) 64 | - self.conv2d_10 = Conv(512, 512) 65 | + self.conv2d_8 = Conv(256, 512, 64) 66 | + self.conv2d_9 = Conv(512, 512, 64) 67 | + self.conv2d_10 = Conv(512, 512, 64) 68 | 69 | # Deconv / UNet 70 | self.up_sampling_1 = nn.UpsamplingNearest2d(scale_factor=2) 71 | # concatenate_1 with conv2d_7, axis = 1 72 | 73 | - self.conv2d_11 = Conv(768, 256) 74 | - self.conv2d_12 = Conv(256, 256) 75 | - self.conv2d_13 = Conv(256, 256) 76 | + self.conv2d_11 = Conv(768, 256, 128) 77 | + self.conv2d_12 = Conv(256, 256, 128) 78 | + self.conv2d_13 = Conv(256, 256, 128) 79 | 80 | self.up_sampling_2 = nn.UpsamplingNearest2d(scale_factor=2) 81 | # concatenate_2 with conv2d_4, axis = 1 82 | 83 | - self.conv2d_14 = Conv(384, 128) 84 | - self.conv2d_15 = Conv(128, 128) 85 | + self.conv2d_14 = Conv(384, 128, 256) 86 | + self.conv2d_15 = Conv(128, 128, 256) 87 | 88 | self.up_sampling_3 = nn.UpsamplingNearest2d(scale_factor=2) 89 | # concatenate_3 with conv2d_2, axis = 1 90 | 91 | - self.conv2d_16 = Conv(192, 64) 92 | - self.conv2d_17 = Conv(64, 64) 93 | + self.conv2d_16 = Conv(192, 64, 512) 94 | + self.conv2d_17 = Conv(64, 64, 512) 95 | self.conv2d_18 = nn.Conv2d(64, 3, kernel_size=(1,1), padding='same') # 输出3张图 96 | # self.conv2d_18 = Conv(64, 1, k=(1,1)) 输出1张图 97 | 98 | -------------------------------------------------------------------------------- /tf2torch/onnx2pt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import onnx 3 | from onnx2torch import convert 4 | from models.tracknet import TrackNet 5 | from torchsummary import summary 6 | 7 | # tensorflow weight 8 | #https://nol.cs.nctu.edu.tw:234/open-source/TrackNetv2/tree/master/3_in_3_out/model906_30 9 | 10 | # 
model906_30 ---> mode_save ---> track.onnx 11 | # 12 | 13 | # Path to ONNX model 14 | onnx_model_path = './track.onnx' 15 | # You can pass the path to the onnx model to convert it or... 16 | 17 | 18 | # onnx_model = onnx.load(onnx_model_path) 19 | # # https://stackoverflow.com/questions/53176229/strip-onnx-graph-from-its-constants-initializers 20 | # onnx_model.graph.ClearField('initializer') 21 | # torch_model = convert(onnx_model) 22 | torch_model = convert(onnx_model_path) 23 | 24 | 25 | onnx_dict = torch_model.state_dict() 26 | 27 | del onnx_dict['initializers.onnx_initializer_0'] 28 | del onnx_dict['initializers.onnx_initializer_1'] 29 | del onnx_dict['initializers.onnx_initializer_2'] 30 | del onnx_dict['initializers.onnx_initializer_3'] 31 | del onnx_dict['initializers.onnx_initializer_4'] 32 | del onnx_dict['initializers.onnx_initializer_5'] 33 | 34 | 35 | 36 | 37 | track_model = TrackNet() 38 | # print(summary(track_model,(9, 288, 512), device="cpu")) 39 | track_dict = track_model.state_dict() 40 | 41 | 42 | 43 | assert(len(onnx_dict)==len(track_dict)) 44 | 45 | convert_dict = {} 46 | # print(track_dict.keys()) 47 | # print(onnx_dict.keys()) 48 | 49 | for k1, k2 in zip(onnx_dict.keys(), track_dict.keys()): 50 | # print(f'onnx: {k1} shape:{onnx_dict[k1].shape} track: {k2} shape:{track_dict[k2].shape}') 51 | convert_dict[k2] = onnx_dict[k1] 52 | 53 | 54 | torch.save(convert_dict, './track.pt') 55 | 56 | 57 | 58 | 59 | # ValueError: Unknown layer: BatchNormalization. 60 | # ValueError: Unknown loss function: custom_loss. 
61 | 62 | # /home/chg/anaconda3/envs/track/lib/python3.10/site-packages/tf2onnx/tf_loader.py 63 | 64 | # line 645 65 | # def custom_loss(y_true, y_pred): 66 | # from tensorflow.python.keras import backend as K 67 | # loss = (-1)*(K.square(1 - y_pred) * y_true * K.log(K.clip(y_pred, K.epsilon(), 1)) + K.square(y_pred) * (1 - y_true) * K.log(K.clip(1 - y_pred, K.epsilon(), 1))) 68 | # return K.mean(loss) 69 | 70 | # def from_keras(model_path, input_names, output_names): 71 | # ... 72 | # custom_objects = {'custom_loss':custom_loss} 73 | 74 | 75 | # size mismatch for conv2d_1.bn.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([64]). 76 | # size mismatch for conv2d_1.bn.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([64]). 77 | # onnx: BatchNormalization_0.weight shape:torch.Size([512]) track: conv2d_1.bn.weight shape:torch.Size([64]) 78 | 79 | # tensorflow中conv2d默认channels_last,默认情况下BatchNormalization会对最后一个通道C通道进行BN,而使用channels_first后, 80 | # BatchNormalization 输入(N, C, H, W), 学习参数, C 81 | 82 | 83 | # https://keras.io/api/layers/normalization_layers/batch_normalization/ 84 | # 在keras和tensorflow中,BatchNormalization默认是对输入的最后一个维度进行归一化,即axis=-1,这通常是特征维度。在pytorch中,BatchNormalization有不同的变体,如BatchNorm1d,BatchNorm2d等,它们默认是对输入的第二个维度进行归一化,即num_features。 -------------------------------------------------------------------------------- /tf2torch/track.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChgygLin/TrackNetV2-pytorch/185674ff3d97ef66f3f34ef111a705fd846dd402/tf2torch/track.pt -------------------------------------------------------------------------------- /tools/Frame_Generator.py: -------------------------------------------------------------------------------- 1 | # 将单个视频分解成图像帧 2 | # python Frame_Generator.py Test/match1/videos/1_05_02.mp4 Test/match1/images/1_05_02 3 | 4 | 
import cv2 5 | import os 6 | import sys 7 | import shutil 8 | 9 | 10 | def extract_video(filePath, outputPath): 11 | if os.path.exists(outputPath): 12 | shutil.rmtree(outputPath) 13 | 14 | os.makedirs(outputPath) 15 | 16 | #Segment the video into frames 17 | cap = cv2.VideoCapture(filePath) 18 | success, count = True, 0 19 | success, image = cap.read() 20 | 21 | while success: 22 | imageFile = os.path.join(outputPath, '{}.jpg'.format(count)) 23 | print(imageFile) 24 | cv2.imwrite(imageFile, image) 25 | count += 1 26 | success, image = cap.read() 27 | 28 | 29 | if __name__ == "__main__": 30 | try: 31 | filePath = sys.argv[1] 32 | outputPath = sys.argv[2] 33 | if (not filePath) or (not outputPath): 34 | raise '' 35 | except: 36 | print('usage: python3 Frame_Generator.py ') 37 | exit(1) 38 | 39 | 40 | extract_video(filePath, outputPath) 41 | 42 | 43 | -------------------------------------------------------------------------------- /tools/Frame_Generator_batch.py: -------------------------------------------------------------------------------- 1 | # 将一个目录下所有的视频依次分解成图像帧, 自动创建同级images目录 2 | # python Frame_Generator_batch.py Test/match1/videos 3 | 4 | import os 5 | import sys 6 | from glob import glob 7 | 8 | from Frame_Generator import extract_video 9 | 10 | # Test/match1/videos/1_05_02.mp4 ----> Test/match1/images/1_05_02/*.jpg 11 | # Test/match1/videos/1_05_02.mp4 ----> Test/match1/images/1_05_02/*.jpg 12 | # ... 
13 | 14 | 15 | def extract_videos(videosPath): 16 | for filePath in glob(os.path.join(videosPath, '*mp4')): 17 | tmp, _ = os.path.splitext(filePath) # Test/match1/videos/1_05_02.mp4 ---> Test/match1/videos/1_05_02 18 | imagePath = tmp.replace('videos', 'images') # Test/match1/videos/1_05_02 ---> Test/match1/images/1_05_02 19 | 20 | extract_video(filePath, imagePath) 21 | 22 | 23 | if __name__ == "__main__": 24 | try: 25 | videosPath = sys.argv[1] 26 | if not videosPath: 27 | raise '' 28 | except: 29 | print('usage: python3 Frame_Generator.py ') 30 | exit(1) 31 | 32 | 33 | extract_videos(videosPath) 34 | -------------------------------------------------------------------------------- /tools/Frame_Generator_rally.py: -------------------------------------------------------------------------------- 1 | # 将一个目录下所有的match比赛,分解成图像帧 2 | # python Frame_Generator_rally.py Test ---> Test/match1、Test/match2、Test/match3等 3 | 4 | import os 5 | import sys 6 | 7 | from Frame_Generator_batch import extract_videos 8 | 9 | 10 | def extract_rally(rallyPath): 11 | for dir in os.listdir(rallyPath): 12 | video_path = os.path.join(rallyPath, dir, "videos") # Test/match1/videos 13 | 14 | print(video_path) 15 | extract_videos(video_path) 16 | 17 | 18 | 19 | if __name__ == "__main__": 20 | try: 21 | rallyPath = sys.argv[1] 22 | if not rallyPath : 23 | raise '' 24 | except: 25 | print('usage: python3 handle_dataset.py ') 26 | exit(1) 27 | 28 | 29 | extract_rally(rallyPath) 30 | -------------------------------------------------------------------------------- /tools/check_labels.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import pandas as pd 5 | import numpy as np 6 | 7 | # 基本的数据校验 8 | 9 | base_path = "/home/chg/Documents/Badminton/merge_dataset/TrackNetV2/" 10 | 11 | 12 | # Amateur/match1 13 | def handle_rally(match_path): 14 | images_path = "{}/images".format(match_path) 15 | labels_path = 
"{}/labels".format(match_path) 16 | 17 | # 校验json标签 18 | for images_dir in os.listdir(images_path): 19 | for js_dir in glob.glob(os.path.join(images_path+"/"+images_dir, '*.json')): 20 | print(js_dir) 21 | 22 | with open(js_dir, 'r') as file: 23 | data = json.load(file) 24 | 25 | for shape in data["shapes"]: 26 | index = shape["label"] 27 | 28 | try: 29 | index = int(index) 30 | except ValueError: 31 | print("Error index: {}".format(index)) 32 | exit() 33 | 34 | if index<=0 or index >=33: 35 | print("Error index: {}".format(index)) 36 | exit() 37 | 38 | # 校验csv标签 39 | # for labels_dir in os.listdir(labels_path): 40 | for csv_dir in glob.glob(os.path.join(labels_path+"/", '*.csv')): 41 | images_base_path = csv_dir.replace("labels", "images").split(".")[0] 42 | 43 | print(csv_dir) 44 | df = pd.read_csv(csv_dir) 45 | frame_nums = df["frame_num"].values 46 | 47 | assert(np.all(np.ediff1d(frame_nums) == 1)) 48 | 49 | for frame_num in frame_nums: 50 | if not os.path.exists("{}/{}.jpg".format(images_base_path, frame_num)): 51 | print("{}/{}.jpg".format(images_base_path, frame_num)) 52 | exit() 53 | 54 | # Amateur 55 | def handle_rally_batch(batch_path): 56 | for rally_dir in os.listdir(batch_path): 57 | 58 | match_path = os.path.join(batch_path, rally_dir) 59 | handle_rally(match_path) 60 | 61 | # from_path: 62 | def handle_base_path(base_path): 63 | for to_batch_dir in os.listdir(base_path): 64 | to_batch_path = os.path.join(base_path, to_batch_dir) 65 | 66 | handle_rally_batch(to_batch_path) 67 | 68 | 69 | handle_base_path(base_path) -------------------------------------------------------------------------------- /tools/handle_Darklabel.py: -------------------------------------------------------------------------------- 1 | 2 | # match1/csv/xxx_ball.csv ---> match1/labels/xxx.csv 3 | # match1/video ---> match1/videos 4 | 5 | # Frame,Visibility,X,Y ---> frame_num,visible,x,y 6 | # 11,1,621,305 ---> 11,1,0.48515625,0.423611111 7 | 8 | 9 | import os 10 | import sys 11 | 
import glob 12 | import cv2 13 | import pandas as pd 14 | 15 | try: 16 | rallyPath = sys.argv[1] 17 | if not rallyPath : 18 | raise '' 19 | except: 20 | print('usage: python3 handle_dataset.py ') 21 | exit(1) 22 | 23 | 24 | for dir in os.listdir(rallyPath): 25 | 26 | # video_path = os.path.join(rallyPath, dir, "video") # Test/match1/video 27 | # csv_path = os.path.join(rallyPath, dir, "csv") # Test/match1/csv 28 | 29 | # assert os.path.isdir(video_path), '这不是一个目录' 30 | # assert os.path.isdir(csv_path), '这不是一个目录' 31 | 32 | # # 目录更名, video->videos, csv->labels 33 | # new_video_path = os.path.join(rallyPath, dir, "videos") # Test/match1/videos 34 | # new_csv_path = os.path.join(rallyPath, dir, "labels") # Test/match1/labels 35 | 36 | # os.rename(video_path, new_video_path) 37 | # os.rename(csv_path, new_csv_path) 38 | 39 | new_video_path = os.path.join(rallyPath, dir, "videos") # Test/match1/video 40 | new_csv_path = os.path.join(rallyPath, dir, "labels") # Test/match1/csv 41 | 42 | # Test/match1/videos/1_05_02.mp4 43 | # Test/match1/labels/1_05_02.csv 44 | for file in os.listdir(new_csv_path): 45 | filename, _ = os.path.splitext(file) 46 | 47 | file_path = os.path.join(new_csv_path, file) # Test/match1/labels/1_07_03.csv 48 | video_path = os.path.join(new_video_path, '{}.mp4'.format(filename)) # Test/match1/videos/1_07_03.mp4 49 | 50 | print("handle video: {}".format(video_path)) 51 | 52 | if os.path.exists(video_path): 53 | # 确定视频宽高 54 | cap = cv2.VideoCapture(video_path) 55 | frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 56 | frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 57 | cap.release() 58 | else: 59 | frame_width, frame_height = 1280, 720 60 | 61 | # 将坐标转化为比例 62 | df = pd.read_csv(file_path) 63 | df.columns = ['frame_num', 'visible', 'x', 'y'] 64 | # df = df.rename(columns={'Frame': 'frame_num', 'Visibility': 'visible', 'X': 'x', 'Y': 'y'}) 65 | 66 | df['x'] = df['x'].astype(float) 67 | df['y'] = df['y'].astype(float) 68 | 69 | df.loc[:, 'x'] /= 
frame_width 70 | df.loc[:, 'y'] /= frame_height 71 | 72 | df.to_csv(file_path, index=False) 73 | -------------------------------------------------------------------------------- /tools/handle_tracknet_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | # 预处理tracknet数据集 https://hackmd.io/Nf8Rh1NrSrqNUzmO0sQKZw 3 | 4 | # match1/csv/xxx_ball.csv ---> match1/labels/xxx.csv 5 | # match1/video ---> match1/videos 6 | 7 | # Frame,Visibility,X,Y ---> frame_num,visible,x,y 8 | # 11,1,621,305 ---> 11,1,0.48515625,0.423611111 9 | 10 | 11 | # 输入参数 12 | # TrackNetV2/Professional 13 | # TrackNetV2/Amateur 14 | # TrackNetV2/Test 15 | 16 | import os 17 | import sys 18 | import glob 19 | import cv2 20 | import pandas as pd 21 | 22 | try: 23 | rallyPath = sys.argv[1] 24 | if not rallyPath : 25 | raise '' 26 | except: 27 | print('usage: python3 handle_dataset.py ') 28 | exit(1) 29 | 30 | 31 | for dir in os.listdir(rallyPath): 32 | 33 | video_path = os.path.join(rallyPath, dir, "video") # Test/match1/video 34 | csv_path = os.path.join(rallyPath, dir, "csv") # Test/match1/csv 35 | 36 | assert os.path.isdir(video_path), '这不是一个目录' 37 | assert os.path.isdir(csv_path), '这不是一个目录' 38 | 39 | # csv更名,去掉'_ball' 40 | for file in glob.glob(os.path.join(csv_path, "*.csv")): 41 | new_name = file.replace('_ball', '') 42 | os.rename(file, new_name) 43 | 44 | # 目录更名, video->videos, csv->labels 45 | new_video_path = os.path.join(rallyPath, dir, "videos") # Test/match1/videos 46 | new_csv_path = os.path.join(rallyPath, dir, "labels") # Test/match1/labels 47 | 48 | os.rename(video_path, new_video_path) 49 | os.rename(csv_path, new_csv_path) 50 | 51 | 52 | # Test/match1/videos/1_05_02.mp4 53 | # Test/match1/labels/1_05_02.csv 54 | for file in os.listdir(new_csv_path): 55 | filename, _ = os.path.splitext(file) 56 | 57 | file_path = os.path.join(new_csv_path, file) # Test/match1/labels/1_07_03.csv 58 | video_path = os.path.join(new_video_path, 
'{}.mp4'.format(filename)) # Test/match1/videos/1_07_03.mp4 59 | 60 | # 确定视频宽高 61 | cap = cv2.VideoCapture(video_path) 62 | frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 63 | frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 64 | cap.release() 65 | 66 | # 将坐标转化为比例 67 | df = pd.read_csv(file_path) 68 | df = df.rename(columns={'Frame': 'frame_num', 'Visibility': 'visible', 'X': 'x', 'Y': 'y'}) 69 | df['x'] = df['x'].astype(float) 70 | df['y'] = df['y'].astype(float) 71 | 72 | df.loc[:, 'x'] /= frame_width 73 | df.loc[:, 'y'] /= frame_height 74 | 75 | df.to_csv(file_path, index=False) 76 | -------------------------------------------------------------------------------- /tools/label_tool.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import cv2 3 | from enum import Enum 4 | import pandas as pd 5 | import numpy as np 6 | import os 7 | import argparse 8 | import sys 9 | import time 10 | 11 | # python label_tool.py ./dataset/match1/videos/1_10_12.mp4 --csv_dir dataset/match/labels 12 | # 若不加csv_dir,则会默认从mp4文件同级目录读取csv 13 | 14 | 15 | # state 0:hidden 1:visible 16 | state_name = ['HIDDEN', 'VISIBLE'] 17 | 18 | keybindings = { 19 | 'next': [ ord('n') ], 20 | 'prev': [ ord('p')], 21 | 22 | 'piece_start': [ ord('s'), ], # 裁剪开始帧 23 | 'piece_end': [ ord('e'), ], # 裁剪结束帧 24 | 25 | 'first_frame': [ ord('z'), ], 26 | 'last_frame': [ ord('x'), ], 27 | 28 | 'forward_frames': [ ord('f'), ], # 前进36帧 29 | 'backward_frames': [ ord('b'), ], # 后退36帧 30 | 31 | 'circle_grow': [ ord('='), ord('+') ], 32 | 'circle_shrink': [ ord('-'), ], 33 | 34 | 'quit': [ ord('q'), ], 35 | } 36 | 37 | 38 | 39 | class VideoPlayer(): 40 | def __init__(self, opt) -> None: 41 | self.jump = 36 42 | 43 | self.cap = cv2.VideoCapture(opt.video_path) 44 | self.width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 45 | self.height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 46 | self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) 
47 | #self.fourcc = int(self.cap.get(cv.CAP_PROP_FOURCC)) 48 | # self.fourcc = cv2.VideoWriter_fourcc('H', 'E', 'V', 'C') 49 | self.fourcc = cv2.VideoWriter_fourcc(*'mp4v') 50 | # self.fourcc = cv2.VideoWriter_fourcc(*'XVID') 51 | self.fps = self.cap.get(cv2.CAP_PROP_FPS) 52 | self.bitrate = self.cap.get(cv2.CAP_PROP_BITRATE) 53 | 54 | # Check video lens! 55 | # ret, frame = self.cap.read() 56 | # frame_cnt = 1 57 | # while ret: 58 | # ret, frame = self.cap.read() 59 | # if ret: 60 | # frame_cnt += 1 61 | # else: 62 | # print("Waringing: {} frame decode error!".format(frame_cnt)) 63 | 64 | # if self.frames != frame_cnt: 65 | # print(" self.frames: {}, frame_cnt: {} ".format(self.frames, frame_cnt)) 66 | # self.frames = frame_cnt 67 | 68 | self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0) 69 | 70 | 71 | self.video_path = Path(opt.video_path) 72 | self.circle_size = 5 73 | if opt.csv_dir is None: 74 | self.csv_path = self.video_path.with_suffix('.csv') 75 | else: 76 | self.csv_path = Path(opt.csv_dir) / Path(self.video_path.stem).with_suffix('.csv') 77 | 78 | self.window = cv2.namedWindow('Frame', cv2.WINDOW_NORMAL) 79 | cv2.resizeWindow('Frame', 1280, 720) 80 | 81 | _, self.frame = self.cap.read() 82 | self.frame_num = 0 83 | 84 | self.piece_start = 0 85 | self.piece_end = 0 86 | 87 | if os.path.exists(self.csv_path): 88 | self.info = pd.read_csv(self.csv_path) 89 | 90 | if len(self.info.index) != self.frames: 91 | print("pd len: {}, camera len: {}".format(len(self.info.index), self.frames)) 92 | print("Number of frames in video and dictionary are not the same!") 93 | print("Fail to load!") 94 | exit(1) 95 | 96 | 97 | # self.info = {'frame_num':[], 'visible':[], 'x':[], 'y':[]} 98 | 99 | # for idx in range(self.frames): 100 | # self.info['frame_num'].append(idx) 101 | # self.info['visible'].append(0) 102 | # self.info['x'].append(0) 103 | # self.info['y'].append(0) 104 | 105 | # print("pandas dataframe len: {}".format(len(self.info))) 106 | 107 | else: 108 | self.info = {k: 
list(v.values()) for k, v in self.info.to_dict().items()} 109 | print("Load labeled {} successfully.".format(self.csv_path)) 110 | else: 111 | print("Create new dictionary") 112 | 113 | self.info = {'frame_num':[], 'visible':[], 'x':[], 'y':[]} 114 | 115 | for idx in range(self.frames): 116 | self.info['frame_num'].append(idx) 117 | self.info['visible'].append(0) 118 | self.info['x'].append(0) 119 | self.info['y'].append(0) 120 | 121 | print("pandas dataframe len: {}".format(len(self.info))) 122 | 123 | cv2.setMouseCallback('Frame',self.markBall) 124 | self.display() 125 | 126 | 127 | def save_piece(self): 128 | self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.piece_start) 129 | 130 | out = cv2.VideoWriter('{}_{}.mp4'.format(self.piece_start+1, self.piece_end+1), self.fourcc, self.fps, (self.width, self.height)) 131 | 132 | frame_cnt = self.piece_start 133 | while frame_cnt <= self.piece_end: 134 | ret, frame = self.cap.read() 135 | out.write(frame) 136 | 137 | frame_cnt += 1 138 | 139 | out.release() 140 | 141 | self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.frame_num) 142 | print("save piece succefully!") 143 | 144 | def markBall(self, event, x, y, flags, param): 145 | x /= self.width 146 | y /= self.height 147 | if event == cv2.EVENT_LBUTTONDOWN: 148 | self.info['frame_num'][self.frame_num] = self.frame_num 149 | self.info['x'][self.frame_num] = x 150 | self.info['y'][self.frame_num] = y 151 | self.info['visible'][self.frame_num] = 1 152 | 153 | elif event == cv2.EVENT_RBUTTONDBLCLK: 154 | self.info['frame_num'][self.frame_num] = self.frame_num 155 | self.info['x'][self.frame_num] = 0 156 | self.info['y'][self.frame_num] = 0 157 | self.info['visible'][self.frame_num] = 0 158 | 159 | 160 | def display(self): 161 | res_frame = self.frame.copy() 162 | res_frame = cv2.putText(res_frame, state_name[self.info['visible'][self.frame_num]], (100, 110), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 255, 0), 2, cv2.LINE_AA) 163 | res_frame = cv2.putText(res_frame, "Frame: {}, Total: 
{}".format(int(self.frame_num+1), int(self.frames)), (100, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 2, cv2.LINE_AA) 164 | res_frame = cv2.putText(res_frame, "Piece: {}-{}".format(int(self.piece_start+1), int(self.piece_end+1)), (100, 170), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 2, cv2.LINE_AA) 165 | 166 | if self.info['visible'][self.frame_num]: 167 | x = int(self.info['x'][self.frame_num] * self.width) 168 | y = int(self.info['y'][self.frame_num] * self.height) 169 | cv2.circle(res_frame, (x, y), self.circle_size, (0, 0, 255), -1) 170 | 171 | cv2.imshow('Frame', res_frame) 172 | 173 | #print("frame num: {}".format(self.frame_num)) 174 | #print(type(self.frame)) 175 | #assert(ret==True) 176 | # frame_num 0---->frames-1 177 | def main_loop(self): 178 | key = cv2.waitKeyEx(1) 179 | if key in keybindings['first_frame']: 180 | self.frame_num = 0 181 | self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.frame_num) 182 | ret, self.frame = self.cap.read() 183 | 184 | assert(ret==True) 185 | 186 | elif key in keybindings['last_frame']: 187 | self.frame_num = self.frames-1 188 | self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.frame_num) # cap.set时, frame_num不用加1 189 | ret, self.frame = self.cap.read() 190 | 191 | print(type(self.frame)) 192 | assert(ret==True) 193 | 194 | 195 | elif key in keybindings['next']: 196 | if self.frame_num < self.frames-1: 197 | ret, self.frame = self.cap.read() 198 | self.frame_num += 1 199 | 200 | assert(ret==True) 201 | 202 | elif key in keybindings['prev']: 203 | time.sleep(0.01) 204 | if self.frame_num > 0: 205 | self.frame_num -= 1 206 | self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.frame_num) 207 | _, self.frame = self.cap.read() 208 | 209 | 210 | elif key in keybindings['forward_frames']: 211 | if self.frame_num < self.frames-1: 212 | for _ in range(self.jump): 213 | if self.frame_num == self.frames-2: # 倒数第二帧,最后一帧使用read() 214 | break 215 | 216 | self.cap.grab() # cap.grab跳过帧, frame_num加1 217 | self.frame_num += 1 218 | 219 | ret, self.frame = 
self.cap.read() 220 | self.frame_num += 1 221 | 222 | 223 | elif key in keybindings['backward_frames']: 224 | if self.frame_num < self.jump: 225 | self.frame_num = 0 226 | else: 227 | self.frame_num -= self.jump 228 | 229 | self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.frame_num) 230 | ret, self.frame = self.cap.read() 231 | 232 | assert(ret==True) 233 | 234 | 235 | elif key in keybindings['circle_grow']: 236 | self.circle_size += 1 237 | elif key in keybindings['circle_shrink']: 238 | self.circle_size -= 1 239 | 240 | 241 | elif key in keybindings['piece_start']: 242 | self.piece_start = self.frame_num 243 | 244 | elif key in keybindings['piece_end']: 245 | self.piece_end = self.frame_num 246 | self.save_piece() 247 | 248 | 249 | elif key in keybindings['quit']: 250 | self.finish() 251 | return 252 | 253 | 254 | self.display() 255 | 256 | 257 | def finish(self): 258 | self.cap.release() 259 | cv2.destroyAllWindows() 260 | df = pd.DataFrame.from_dict(self.info).sort_values(by=['frame_num'], ignore_index=True) 261 | df.to_csv(self.csv_path, index=False) 262 | 263 | 264 | def __del__(self): 265 | self.finish() 266 | 267 | 268 | def parse_opt(): 269 | parser = argparse.ArgumentParser() 270 | parser.add_argument('video_path', type=str, nargs='?', default=None, help='Path to the video file.') 271 | parser.add_argument('--csv_dir', type=str, default=None, help='Path to the directory where csv file should be saved. 
If not specified, csv file will be saved in the same directory as the video file.') 272 | parser.add_argument('--remove_duplicate_frames', type=bool, default=False, help='Should identical consecutie frames be reduces to one frame.') 273 | opt = parser.parse_args() 274 | return opt 275 | 276 | 277 | def remove_duplicate_frames(video_path, output_path): 278 | # Open the video file 279 | vid = cv2.VideoCapture(video_path) 280 | 281 | # Set the frame width and height 282 | frame_width = int(vid.get(cv2.CAP_PROP_FRAME_WIDTH)) 283 | frame_height = int(vid.get(cv2.CAP_PROP_FRAME_HEIGHT)) 284 | fps = int(vid.get(cv2.CAP_PROP_FPS)) 285 | 286 | # Create a VideoWriter object for the output video file 287 | out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height)) 288 | 289 | # Read and process the frames one by one 290 | previous_frame = None 291 | while True: 292 | # Read the next frame 293 | success, frame = vid.read() 294 | 295 | # If we reached the end of the video, break the loop 296 | if not success: 297 | break 298 | 299 | # If the current frame is not a duplicate, write it to the output video 300 | if previous_frame is None or cv2.PSNR(frame, previous_frame) < 40.: 301 | out.write(frame) 302 | 303 | # Update the previous frame 304 | previous_frame = frame 305 | print('finished removing duplicates') 306 | 307 | 308 | 309 | if __name__ == '__main__': 310 | opt = parse_opt() 311 | 312 | if opt.video_path is None: 313 | if getattr(sys, 'frozen', False): 314 | application_path = os.path.dirname(sys.executable) 315 | elif __file__: 316 | application_path = os.path.dirname(__file__) 317 | p = Path(application_path) 318 | video_path = next(p.glob('*.mp4')) 319 | toRemove = input('Should duplicated, consecutive frames be deleted? 
Insert "y" or "n": \n') 320 | if toRemove == 'y': 321 | bez_duplikatow_video_path = str(video_path.with_stem(video_path.stem + '_no_dups')) 322 | remove_duplicate_frames(str(video_path), bez_duplikatow_video_path) 323 | video_path = bez_duplikatow_video_path 324 | opt.video_path = str(video_path) 325 | 326 | # run as a CLI script 327 | elif opt.remove_duplicate_frames == True: 328 | remove_duplicate_frames(opt.video_path, opt.video_path) 329 | 330 | player = VideoPlayer(opt) 331 | while(player.cap.isOpened()): 332 | player.main_loop() 333 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | 4 | import os 5 | import sys 6 | import numpy as np 7 | from pathlib import Path 8 | from tqdm import tqdm 9 | from argparse import ArgumentParser 10 | from tensorboardX import SummaryWriter 11 | 12 | from models.tracknet import TrackNet 13 | from utils.dataloaders import create_dataloader 14 | from utils.general import check_dataset, outcome, evaluation, tensorboard_log 15 | 16 | 17 | # from yolov5 detect.py 18 | FILE = Path(__file__).resolve() 19 | ABS_ROOT = FILE.parents[0] # YOLOv5 root directory 20 | if str(ABS_ROOT) not in sys.path: 21 | sys.path.append(str(ABS_ROOT)) # add ROOT to PATH 22 | ROOT = Path(os.path.relpath(ABS_ROOT, Path.cwd())) # relative 23 | 24 | 25 | def wbce_loss(y_true, y_pred): 26 | return -1*( 27 | ((1-y_pred)**2) * y_true * torch.log(torch.clamp(y_pred, min=1e-07, max=1)) + 28 | (y_pred**2) * (1-y_true) * torch.log(torch.clamp(1-y_pred, min=1e-07, max=1)) 29 | ).sum() 30 | 31 | 32 | def validation_loop(device, model, val_loader, log_writer, epoch): 33 | model.eval() 34 | 35 | loss_sum = 0 36 | TP = TN = FP1 = FP2 = FN = 0 37 | 38 | with torch.inference_mode(): 39 | pbar = tqdm(val_loader, ncols=180) 40 | for batch_index, (X, y) in enumerate(pbar): 41 | X, y = X.to(device), y.to(device) 42 
| y_pred = model(X) 43 | 44 | loss_sum += wbce_loss(y, y_pred).item() 45 | 46 | y_ = y.detach().cpu().numpy() 47 | y_pred_ = y_pred.detach().cpu().numpy() 48 | 49 | y_pred_ = (y_pred_ > 0.5).astype('float32') 50 | (tp, tn, fp1, fp2, fn) = outcome(y_pred_, y_) 51 | TP += tp 52 | TN += tn 53 | FP1 += fp1 54 | FP2 += fp2 55 | FN += fn 56 | 57 | (accuracy, precision, recall) = evaluation(TP, TN, FP1, FP2, FN) 58 | 59 | pbar.set_description('Val loss: {:.6f} | TP: {}, TN: {}, FP1: {}, FP2: {}, FN: {} | Accuracy: {:.4f}, Precision: {:.4f}, Recall: {:.4f}'.format( \ 60 | loss_sum / ((batch_index+1)*X.shape[0]), TP, TN, FP1, FP2, FN, accuracy, precision, recall)) 61 | 62 | tensorboard_log(log_writer, "Val", loss_sum / ((batch_index+1)*X.shape[0]), TP, TN, FP1, FP2, FN, epoch) 63 | 64 | return loss_sum/len(val_loader) 65 | 66 | 67 | def training_loop(device, model, optimizer, lr_scheduler, train_loader, val_loader, start_epoch, epochs, save_dir): 68 | best_val_loss = float('inf') 69 | 70 | checkpoint_period = 3 71 | log_period = 100 72 | 73 | log_dir = '{}/logs'.format(save_dir) 74 | if not os.path.exists(log_dir): 75 | os.makedirs(log_dir) 76 | 77 | log_writer = SummaryWriter(log_dir) 78 | 79 | for epoch in range(start_epoch, epochs): 80 | print("\n==================================================================================================") 81 | tqdm.write("Epoch: {} / {}\n".format(epoch, epochs)) 82 | running_loss = 0.0 83 | TP = TN = FP1 = FP2 = FN = 0 84 | 85 | model.train() 86 | pbar = tqdm(train_loader, ncols=180) 87 | for batch_index, (X, y) in enumerate(pbar): 88 | X, y = X.to(device), y.to(device) 89 | optimizer.zero_grad() 90 | 91 | y_pred = model(X) 92 | 93 | loss = wbce_loss(y, y_pred) 94 | loss.backward() 95 | 96 | optimizer.step() 97 | 98 | running_loss += loss.item() 99 | 100 | y_ = y.detach().cpu().numpy() 101 | y_pred_ = y_pred.detach().cpu().numpy() 102 | 103 | y_pred_ = (y_pred_ > 0.5).astype('float32') 104 | (tp, tn, fp1, fp2, fn) = 
outcome(y_pred_, y_) 105 | TP += tp 106 | TN += tn 107 | FP1 += fp1 108 | FP2 += fp2 109 | FN += fn 110 | 111 | (accuracy, precision, recall) = evaluation(TP, TN, FP1, FP2, FN) 112 | 113 | pbar.set_description('Train loss: {:.6f} | TP: {}, TN: {}, FP1: {}, FP2: {}, FN: {} | Accuracy: {:.4f}, Precision: {:.4f}, Recall: {:.4f}'.format( \ 114 | running_loss / ((batch_index+1)*X.shape[0]), TP, TN, FP1, FP2, FN, accuracy, precision, recall)) 115 | 116 | if batch_index % log_period == 0: 117 | with torch.inference_mode(): 118 | images = [ 119 | torch.unsqueeze(y[0,0,:,:], 0).repeat(3,1,1).cpu(), 120 | torch.unsqueeze(y_pred[0,0,:,:], 0).repeat(3,1,1).cpu(), 121 | ] 122 | 123 | images.append(X[0,(0,1,2),:,:].cpu()) 124 | res = X[0, (0,1,2),:,:] * y[0,0,:,:] 125 | 126 | images.append(res.cpu()) 127 | grid = torchvision.utils.make_grid(images, nrow=1) 128 | 129 | torchvision.utils.save_image(grid, '{}/epoch_{}_batch{}.png'.format(log_dir, epoch, batch_index)) 130 | 131 | if val_loader is not None: 132 | best = False 133 | val_loss = validation_loop(device, model, val_loader, log_writer, epoch) 134 | if val_loss < best_val_loss: 135 | best_val_loss = val_loss 136 | best = True 137 | 138 | if epoch % checkpoint_period == checkpoint_period - 1: 139 | tqdm.write('\n--- Saving weights to: {}/last.pt ---'.format(save_dir)) 140 | torch.save(model.state_dict(), '{}/last.pt'.format(save_dir)) 141 | 142 | if best: 143 | tqdm.write('--- Saving weights to: {}/best.pt ---'.format(save_dir)) 144 | torch.save(model.state_dict(), '{}/best.pt'.format(save_dir)) 145 | 146 | tensorboard_log(log_writer, "Train", running_loss / ((batch_index+1)*X.shape[0]), TP, TN, FP1, FP2, FN, epoch) 147 | 148 | print('lr: {}'.format(lr_scheduler.get_last_lr())) 149 | lr_scheduler.step() 150 | 151 | if epoch%10 == 0: 152 | ckpt_dir = '{}/checkpoint'.format(save_dir) 153 | if not os.path.exists(ckpt_dir): 154 | os.makedirs(ckpt_dir) 155 | 156 | cur_ckpt = '{}/{}/ckpt_{}.pt'.format(ABS_ROOT, ckpt_dir, epoch) 
157 | latest_ckpt = '{}/{}/ckpt_latest.pt'.format(ABS_ROOT, ckpt_dir) 158 | 159 | checkpoint = { 160 | "net": model.state_dict(), 161 | 'optimizer': optimizer.state_dict(), 162 | "epoch": epoch+1, 163 | 'lr_scheduler': lr_scheduler.state_dict() 164 | } 165 | 166 | torch.save(checkpoint, cur_ckpt) 167 | os.system('ln -sf {} {}'.format(cur_ckpt, latest_ckpt)) 168 | print("save checkpoint {}".format(cur_ckpt)) 169 | 170 | 171 | def parse_opt(): 172 | parser = ArgumentParser() 173 | 174 | parser.add_argument('--data', type=str, default=ROOT / 'data/match/test.yaml', help='Path to dataset.') 175 | parser.add_argument('--weights', type=str, default=ROOT / 'best.pt', help='Path to trained model weights.') 176 | parser.add_argument('--epochs', type=int, default=100, help='total training epochs') 177 | parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[288, 512], help='image size h,w') 178 | parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch') 179 | parser.add_argument('--project', default=ROOT / 'runs/train', help='save results to project/name') 180 | parser.add_argument('--resume', action='store_true', help='whether load checkpoint for resume') 181 | 182 | opt = parser.parse_args() 183 | return opt 184 | 185 | 186 | def main(opt): 187 | d_save_dir = str(opt.project) 188 | f_weights = str(opt.weights) 189 | epochs = opt.epochs 190 | batch_size = opt.batch_size 191 | f_data = str(opt.data) 192 | imgsz = opt.imgsz 193 | 194 | start_epoch = 0 195 | 196 | data_dict = check_dataset(f_data) 197 | train_path, val_path = data_dict['train'], data_dict['val'] 198 | 199 | if not os.path.exists(d_save_dir): 200 | os.makedirs(d_save_dir) 201 | 202 | 203 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 204 | model = TrackNet().to(device) 205 | 206 | optimizer = torch.optim.Adadelta(model.parameters(), lr=0.99) 207 | 208 | # lr_scheduler = 
torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=epochs) 209 | lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.9) 210 | 211 | if opt.resume: 212 | default_ckpt = "{}/checkpoint/ckpt_latest.pt".format(d_save_dir) 213 | checkpoint = torch.load(default_ckpt) 214 | 215 | model.load_state_dict(checkpoint['net']) 216 | optimizer.load_state_dict(checkpoint['optimizer']) 217 | start_epoch = checkpoint['epoch'] 218 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 219 | 220 | print("load checkpoint {}".format(default_ckpt)) 221 | else: 222 | if os.path.exists(f_weights): 223 | print("load pretrain weights {}".format(f_weights)) 224 | model.load_state_dict(torch.load(f_weights)) 225 | else: 226 | print("train from scratch") 227 | 228 | train_loader = create_dataloader(train_path, imgsz, batch_size=batch_size, augment=True, shuffle=True) # augment in training 229 | val_loader = create_dataloader(val_path, imgsz, batch_size=batch_size) 230 | 231 | 232 | training_loop(device, model, optimizer, lr_scheduler, train_loader, val_loader, start_epoch, epochs, d_save_dir) 233 | 234 | 235 | 236 | if __name__ == '__main__': 237 | opt = parse_opt() 238 | main(opt) -------------------------------------------------------------------------------- /utils/augmentations.py: -------------------------------------------------------------------------------- 1 | # YOLOv5 🚀 by Ultralytics, AGPL-3.0 license 2 | """Image augmentation functions.""" 3 | 4 | import math 5 | import random 6 | 7 | import cv2 8 | import numpy as np 9 | import torch 10 | import torchvision.transforms as T 11 | import torchvision.transforms.functional as TF 12 | import albumentations as A 13 | 14 | 15 | class Albumentations: 16 | # YOLOv5 Albumentations class (optional, only used if package is installed) 17 | def __init__(self, imgsz=[288, 512]): 18 | self.transform = None 19 | prefix = "albumentations: " 20 | 21 | T = [ 22 | A.RandomResizedCrop(height=imgsz[0], 
width=imgsz[1], scale=(0.8, 1.0), ratio=(0.9, 1.11), p=0.0), 23 | A.Blur(p=0.1), 24 | A.MedianBlur(p=0.1), 25 | A.ToGray(p=0.1), 26 | A.CLAHE(p=0.1), 27 | A.RandomBrightnessContrast(p=0.0), 28 | A.RandomGamma(p=0.0), 29 | A.ImageCompression(quality_lower=75, p=0.0), 30 | ] # transforms 31 | self.transform = A.Compose(T, keypoint_params=A.KeypointParams(format="xy", remove_invisible=False)) 32 | 33 | print(prefix + ", ".join(f"{x}".replace("always_apply=False, ", "") for x in T if x.p)) 34 | 35 | 36 | def __call__(self, im, kps, p=1.0): # With probability p, run the Compose pipeline on the image and its (x, y) keypoints together; returns (im, kps), kps being converted to a NumPy array only when the pipeline ran 37 | if self.transform and random.random() < p: 38 | new = self.transform(image=im, keypoints=kps) # transformed 39 | im, kps = new["image"], np.array(new["keypoints"]) 40 | 41 | return im, kps 42 | 43 | 44 | def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5): 45 | # HSV color-space augmentation of the BGR image `im`, done in place (written back via dst=im, nothing is returned); per-channel gains in [1-gain, 1+gain] are drawn from NumPy's global RNG (np.random) 46 | if hgain or sgain or vgain: 47 | r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1 # random gains 48 | hue, sat, val = cv2.split(cv2.cvtColor(im, cv2.COLOR_BGR2HSV)) 49 | dtype = im.dtype # uint8 50 | 51 | x = np.arange(0, 256, dtype=r.dtype) # lookup-table inputs 0..255 52 | lut_hue = ((x * r[0]) % 180).astype(dtype) # hue wraps around at 180 53 | lut_sat = np.clip(x * r[1], 0, 255).astype(dtype) 54 | lut_val = np.clip(x * r[2], 0, 255).astype(dtype) 55 | 56 | im_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))) 57 | cv2.cvtColor(im_hsv, cv2.COLOR_HSV2BGR, dst=im) # no return needed 58 | 59 | 60 | def random_perspective( 61 | im, kps, degrees=10, translate=0.1, scale=0.1, shear=10, perspective=0.0, border=(0, 0) 62 | ): 63 | # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(0.1, 0.1), scale=(0.9, 1.1), shear=(-10, 10)) 64 | # Random affine (or perspective, when perspective != 0) warp of `im`; the same matrix M is applied to the (n, 2) keypoints `kps`, which are then clipped to the output size. Every draw uses Python's `random`, so re-seeding `random` reproduces the same warp. 65 | height = im.shape[0] + border[0] * 2 # shape(h,w,c) 66 | width = im.shape[1] + border[1] * 2 67 | 68 | # Center 69 | C = np.eye(3) 70 | C[0, 2] = -im.shape[1] / 2 # x translation (pixels) 71 | C[1, 2] = -im.shape[0] / 2 # y translation (pixels) 72 | 73 | # Perspective 74 | P = np.eye(3) 75 | P[2, 0]
= random.uniform(-perspective, perspective) # x perspective (about y) 76 | P[2, 1] = random.uniform(-perspective, perspective) # y perspective (about x) 77 | 78 | # Rotation and Scale 79 | R = np.eye(3) 80 | a = random.uniform(-degrees, degrees) 81 | # a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations 82 | s = random.uniform(1 - scale, 1 + scale) 83 | # s = 2 ** random.uniform(-scale, scale) 84 | R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s) 85 | 86 | # Shear 87 | S = np.eye(3) 88 | S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # x shear (deg) 89 | S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # y shear (deg) 90 | 91 | # Translation 92 | T = np.eye(3) 93 | T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width # x translation (pixels) 94 | T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height # y translation (pixels) 95 | 96 | # Combined transformation matrix 97 | M = T @ S @ R @ P @ C # order of operations (right to left) is IMPORTANT 98 | if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any(): # image changed 99 | if perspective: 100 | im = cv2.warpPerspective(im, M, dsize=(width, height), borderValue=(114, 114, 114)) 101 | else: # affine 102 | im = cv2.warpAffine(im, M[:2], dsize=(width, height), borderValue=(114, 114, 114)) 103 | 104 | # Visualize (disabled debug snippet) 105 | # import matplotlib.pyplot as plt 106 | # ax = plt.subplots(1, 2, figsize=(12, 6))[1].ravel() 107 | # ax[0].imshow(im[:, :, ::-1]) # base 108 | # ax[1].imshow(im2[:, :, ::-1]) # warped 109 | # plt.show() 110 | 111 | # Transform label coordinates 112 | n = len(kps) # number of keypoints 113 | 114 | xy = np.ones((n, 3)) 115 | xy[:, :2] = kps # homogeneous coordinates (x, y, 1) 116 | 117 | xy = xy @ M.T # transform 118 | xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]) # perspective rescale or affine 119 | 120 | # clip keypoints into [0, width] x [0, height] (upper bound inclusive) 121 | xy[:, 0] = xy[:, 0].clip(0, width) 122 | xy[:, 1] = xy[:, 1].clip(0, height) 123 | 124 | return im, xy # warped image and transformed (n, 2) float keypoints 125 | 126
| 127 | def random_flip(im, kps, p=0.5): # With probability p, mirror `im` left-right and reflect the keypoint x coordinates; `kps` is modified in place and also returned 128 | # Flip left-right 129 | if random.random() < p: 130 | im = np.fliplr(im) 131 | kps[:, 0] = im.shape[1] - kps[:, 0] # reflect about the image width: x -> W - x 132 | 133 | return im, kps 134 | 135 | -------------------------------------------------------------------------------- /utils/dataloaders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | 4 | from torch.utils.data import Dataset, DataLoader 5 | 6 | import os 7 | import time 8 | import random 9 | import cv2 10 | import pandas as pd 11 | import numpy as np 12 | 13 | from utils.augmentations import random_perspective, Albumentations, augment_hsv, random_flip 14 | 15 | 16 | class ToTensor: 17 | # YOLOv5 ToTensor class for image preprocessing: HWC BGR array -> CHW RGB float tensor scaled to [0, 1] 18 | def __init__(self, half=False): # half=True yields float16 instead of float32 19 | super().__init__() 20 | self.half = half 21 | 22 | def __call__(self, im): # im = np.array HWC in BGR order 23 | im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1]) # HWC to CHW -> BGR to RGB -> contiguous 24 | im = torch.from_numpy(im) # to torch 25 | im = im.half() if self.half else im.float() # uint8 to fp16/32 26 | im /= 255.0 # 0-255 to 0.0-1.0 27 | return im 28 | 29 | 30 | # num_workers https://zhuanlan.zhihu.com/p/568076554 31 | # batch size 20 nw 1 ---> 1.35 nw 8 ---> 1.75 32 | def create_dataloader(path, # Build a LoadImagesAndLabels dataset over `path` (one directory or a list) and wrap it in a DataLoader 33 | imgsz=[288, 512], # (h, w) 34 | batch_size=1, 35 | sq=3, # frames per sample 36 | augment=False, 37 | workers=8, 38 | shuffle=False): 39 | print("create dataloader image size: {}".format(imgsz)) 40 | 41 | dataset = LoadImagesAndLabels(path, imgsz, batch_size, sq, augment) 42 | 43 | batch_size = min(batch_size, len(dataset)) # never larger than the dataset 44 | nd = torch.cuda.device_count() # number of CUDA devices 45 | nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers: min of CPUs per CUDA device, batch_size and `workers`; 0 when batch_size <= 1 46 | # print("num_workers: {}".format(nw)) 47 | 48 | return DataLoader(dataset, batch_size=batch_size, num_workers=nw, shuffle=shuffle) 49 | 50 | # assert
match/images exist, return path/images/xxx list 51 | def get_match_image_list(match_path): 52 | base_images = os.path.join(match_path, "images") 53 | assert os.path.exists(base_images), base_images+" is invalid" 54 | 55 | image_dir_list = [f.path for f in os.scandir(base_images) if f.is_dir()] 56 | return image_dir_list 57 | 58 | # return all match image list 59 | def get_rally_image_list(rally_path): 60 | image_dir_list = [] 61 | 62 | for match_name in os.listdir(rally_path): 63 | image_dir_list.extend(get_match_image_list(os.path.join(rally_path, match_name))) 64 | 65 | return image_dir_list 66 | 67 | class LoadImagesAndLabels(Dataset): 68 | def __init__(self, 69 | path, 70 | imgsz=[288, 512], 71 | batch_size=1, 72 | sq=3,# 网络输入几张图 73 | augment=False, 74 | ): 75 | self.imgsz = imgsz 76 | self.path = path 77 | self.batch_size = batch_size 78 | self.sq = sq 79 | self.augment = augment 80 | self.albumentations = Albumentations(imgsz) if augment else None 81 | 82 | self.image_dir_list = [] # 所有的图片目录 83 | self.label_path_list = [] # 所有的样本路径,文件, 与image_dir_list的元素一一对应 84 | self.df_list = [] # labels 85 | self.lens = [] # 总共的样本数量 86 | 87 | 88 | # image : "./dataset1/match2/images/1_0_1" 89 | # image list : ["./dataset1/match2/images/1_0_1", "./dataset1/match2/images/1_1_1"] 90 | 91 | # match : "./dataset1/match2" 92 | # match list : ["./dataset1/match1", "./dataset1/match2"] 93 | 94 | # rally : "Professional" 95 | # rally list : ["Amateur", "Professional"] 96 | 97 | for p in path if isinstance(path, list) else [path]: 98 | if p[-1] == '/': 99 | p = p[:-1] 100 | 101 | if "images" in p: # image 102 | self.image_dir_list.append(p) 103 | elif "match" in p: # match 104 | self.image_dir_list.extend(get_match_image_list(p)) 105 | else: # rally 106 | self.image_dir_list.extend(get_rally_image_list(p)) 107 | 108 | 109 | # 校验csv标签长度和img目录文件数量是否一致 110 | 111 | print("\n") 112 | print(self.image_dir_list) 113 | 114 | # TODO:::::::: 115 | # Check cache 116 | 117 | # 读取csv 118 | 
print("\n") 119 | for csv_path in self.image_dir_list: 120 | label_path = "{}.{}".format(csv_path.replace('images', 'labels'), "csv") 121 | self.label_path_list.append(label_path) 122 | 123 | df = pd.read_csv(label_path) 124 | df_len = len(df.index) 125 | 126 | self.lens.append(df_len - self.sq + 1) 127 | self.df_list.append(df) 128 | print("{} len: {}".format(csv_path, df_len)) 129 | print("\n") 130 | 131 | def __len__(self): 132 | return sum(self.lens) 133 | 134 | def __getitem__(self, index): 135 | rel_index = index 136 | 137 | # 判断当前样本在哪个df中, index从0开始 138 | for ix in range(len(self.lens)): 139 | if rel_index < self.lens[ix]: 140 | break 141 | else: 142 | rel_index -= self.lens[ix] 143 | 144 | #print("sample {} use label:{} relative index: {}".format(index, self.label_path_list[ix], rel_index)) 145 | 146 | images, heatmaps = self._get_sample(self.image_dir_list[ix], self.df_list[ix], rel_index) 147 | 148 | return images, heatmaps 149 | 150 | 151 | def _get_sample(self, image_dir, label_data, image_rel_index): 152 | images = [] 153 | heatmaps = [] 154 | 155 | w = self.imgsz[1] 156 | h = self.imgsz[0] 157 | 158 | rd_state = random.getstate() 159 | rd_seed = time.time() 160 | for i in range(self.sq): 161 | #if (int(label_data['frame_num'][image_rel_index+i]) != int(image_rel_index+i)): 162 | # print(image_dir) 163 | # print(label_data['frame_num']) 164 | # print("{} ---> {}".format(label_data['frame_num'][image_rel_index+i], image_rel_index+i)) 165 | # assert(int(label_data['frame_num'][image_rel_index+i]) == int(image_rel_index+i)) 166 | 167 | 168 | image_path = image_dir + "/" + str(label_data['frame_num'][image_rel_index+i]) + ".jpg" 169 | img = cv2.imread(image_path) # BGR 170 | 171 | interp = cv2.INTER_LINEAR if (self.augment) else cv2.INTER_AREA 172 | img = cv2.resize(img, (w, h), interpolation=interp) 173 | 174 | 175 | visible = label_data['visible'][image_rel_index+i] 176 | x = label_data['x'][image_rel_index+i] 177 | y = label_data['y'][image_rel_index+i] 
178 | 179 | kps_int = np.array([int(w*x), int(h*y), int(visible)]).reshape(1, -1) 180 | kps_xy = kps_int[:, :2] 181 | assert(len(kps_xy) == 1) 182 | 183 | if self.augment: 184 | # 使用系统时间作为种子值, 保证多张图片的增强策略一致 185 | random.seed(rd_seed) 186 | 187 | img, kps_xy = random_perspective(img, kps_xy) 188 | 189 | img, kps_xy = self.albumentations(img, kps_xy) 190 | 191 | augment_hsv(img, hgain=0.015, sgain=0.7, vgain=0.4) 192 | 193 | img, kps_xy = random_flip(img, kps_xy) 194 | 195 | # kps_int will return 196 | kps_int[:, :2] = kps_xy 197 | kps_int = kps_int.astype(int) 198 | 199 | 200 | if visible == 0: 201 | heatmap = self._gen_heatmap(w, h, -1, -1) 202 | else: 203 | heatmap = self._gen_heatmap(w, h, int(w*x), int(h*y)) 204 | 205 | # x, y, visible 206 | if kps_int[0][2] == 0: 207 | heatmap = self._gen_heatmap(w, h, -1, -1) 208 | else: 209 | x = kps_int[0][0] 210 | y = kps_int[0][1] 211 | 212 | heatmap = self._gen_heatmap(w, h, x, y) 213 | 214 | 215 | img = ToTensor()(img) 216 | 217 | images.append(img) 218 | heatmaps.append(heatmap) 219 | 220 | random.setstate(rd_state) 221 | 222 | images = torch.concatenate(images) # 平铺RGB维度 223 | heatmaps = torch.tensor(np.array(heatmaps), requires_grad=False, dtype=torch.float32) 224 | 225 | return images, heatmaps 226 | 227 | 228 | def _gen_heatmap(self, w, h, cx, cy, r=2.5, mag=1): 229 | if cx < 0 or cy < 0: 230 | return np.zeros((h, w)) 231 | x, y = np.meshgrid(np.linspace(1, w, w), np.linspace(1, h, h)) 232 | heatmap = ((y - (cy + 1))**2) + ((x - (cx + 1))**2) 233 | heatmap[heatmap <= r**2] = 1 234 | heatmap[heatmap > r**2] = 0 235 | 236 | return heatmap*mag 237 | 238 | 239 | def _make_gaussian(size=(1920, 1080), center=(0.5, 0.5), fwhm=(5, 5)): 240 | """ Make a square gaussian kernel. 
241 | 242 | size: (width, height) of the output; need not be square, the returned array has shape (height, width) 243 | center: (x, y) of the peak as fractions of size 244 | fwhm: full width at half maximum along x and y, in pixels 245 | 246 | source: https://stackoverflow.com/questions/7687679/how-to-generate-2d-gaussian-with-python 247 | """ 248 | 249 | x = np.arange(0, size[0], 1, float) 250 | y = np.arange(0, size[1], 1, float)[:,np.newaxis] 251 | 252 | x0 = size[0]*center[0] 253 | y0 = size[1]*center[1] 254 | 255 | return np.exp(-4*np.log(2) * ((x-x0)**2/fwhm[0]**2 + (y-y0)**2/fwhm[1]**2)) 256 | 257 | 258 | if __name__ == "__main__": # smoke test: save heatmap + frame grids for up to 11 batches to ./runs/loader_test (hard-codes 3 frames, the sq default) 259 | if not os.path.exists('./runs/loader_test'): 260 | os.makedirs('./runs/loader_test') 261 | 262 | batch_size = 1 263 | test_loader = create_dataloader("./example_dataset/match/images/1_10_12", batch_size=batch_size) 264 | 265 | for index, (_images,_heatmaps) in enumerate(test_loader): 266 | jj = 0 267 | # for jj in range(batch_size): 268 | hms = [] 269 | for ii in range(3): 270 | hm = _heatmaps[jj,ii,:,:].repeat(3,1,1) # translated original note: "odd - why does this give a grayscale image without *255?" Presumably save_image expects float values in [0, 1]; TODO confirm 271 | hms.append(hm) 272 | # torchvision.utils.save_image(hms, './runs/loader_test/batch{}_{}_heatmap.png'.format(index, jj)) 273 | 274 | ims = [] 275 | for ii in range(3): 276 | im = _images[jj,(0+ii*3,1+ii*3,2+ii*3),:,:] # translated original note: "odd - why does this give a color image without *255?" The frames are already in [0, 1] (ToTensor divides by 255); presumably that is the range save_image expects - TODO confirm 277 | ims.append(im) 278 | hms.append(im) 279 | # torchvision.utils.save_image(ims, './runs/loader_test/batch{}_{}_image.png'.format(index, jj)) 280 | 281 | hms = torchvision.utils.make_grid(hms, nrow=3) 282 | torchvision.utils.save_image(hms, './runs/loader_test/batch{}_{}.png'.format(index, jj)) 283 | 284 | if index >= 10: 285 | break 286 | 287 | 288 | -------------------------------------------------------------------------------- /utils/general.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from pathlib import Path 3 | import cv2 4 | import numpy as np 5 | 6 | 7 | 8 | FILE = Path(__file__).resolve() 9 | ROOT = FILE.parents[1] # repository root (parent of utils/) 10 | 11 | 12 | def yaml_load(file): 13 | # Load a YAML file with yaml.safe_load (the file is opened with errors='ignore') 14 |
with open(file, errors='ignore') as f: 15 | return yaml.safe_load(f) 16 | 17 | def check_dataset(data): 18 | if isinstance(data, (str, Path)): 19 | data = yaml_load(data) # dictionary 20 | 21 | path = Path(data.get('path')) # optional 'path' default to '.' 22 | if not path.is_absolute(): 23 | path = (ROOT / path).resolve() 24 | data['path'] = path # download scripts 25 | 26 | for k in 'train', 'val': 27 | if data.get(k): # prepend path 28 | if isinstance(data[k], str): 29 | x = (path / data[k]).resolve() 30 | if not x.exists() and data[k].startswith('../'): 31 | x = (path / data[k][3:]).resolve() 32 | data[k] = str(x) 33 | else: 34 | data[k] = [str((path / x).resolve()) for x in data[k]] 35 | 36 | return data 37 | 38 | 39 | # img: 0/1 binary image, numpy array. 40 | def get_shuttle_position(img): 41 | if np.amax(img) <= 0: 42 | # (visible, cx, cy) 43 | return (0, 0, 0) 44 | 45 | else: 46 | (cnts, _) = cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 47 | assert (len(cnts) != 0) 48 | 49 | rects = [cv2.boundingRect(ctr) for ctr in cnts] 50 | max_area_idx = 0 51 | max_area = rects[max_area_idx][2] * rects[max_area_idx][3] 52 | 53 | for ii in range(len(rects)): 54 | area = rects[ii][2] * rects[ii][3] 55 | if area > max_area: 56 | max_area_idx = ii 57 | max_area = area 58 | 59 | target = rects[max_area_idx] 60 | (cx, cy) = (int(target[0] + target[2] / 2), int(target[1] + target[3] / 2)) 61 | 62 | # (visible, cx, cy) 63 | return (1, cx, cy) 64 | 65 | 66 | def outcome(y_pred, y_true, tol=3): # [batch, 3, h, w] 67 | n = y_pred.shape[0] 68 | i = 0 69 | TP = TN = FP1 = FP2 = FN = 0 70 | while i < n: 71 | for j in range(3): 72 | if np.amax(y_pred[i][j]) == 0 and np.amax(y_true[i][j]) == 0: 73 | TN += 1 74 | elif np.amax(y_pred[i][j]) > 0 and np.amax(y_true[i][j]) == 0: 75 | FP2 += 1 76 | elif np.amax(y_pred[i][j]) == 0 and np.amax(y_true[i][j]) > 0: 77 | FN += 1 78 | elif np.amax(y_pred[i][j]) > 0 and np.amax(y_true[i][j]) > 0: 79 | h_pred = y_pred[i][j] * 
255 80 | h_true = y_true[i][j] * 255 81 | h_pred = h_pred.astype('uint8') 82 | h_true = h_true.astype('uint8') 83 | 84 | (_, cx_pred, cy_pred) = get_shuttle_position(h_pred) 85 | (_, cx_true, cy_true) = get_shuttle_position(h_true) 86 | 87 | dist = np.sqrt(pow(cx_pred-cx_true, 2)+pow(cy_pred-cy_true, 2)) 88 | 89 | if dist > tol: 90 | FP1 += 1 91 | else: 92 | TP += 1 93 | i += 1 94 | return (TP, TN, FP1, FP2, FN) 95 | 96 | 97 | def evaluation(TP, TN, FP1, FP2, FN): 98 | try: 99 | accuracy = (TP + TN) / (TP + TN + FP1 + FP2 + FN) 100 | except: 101 | accuracy = 0 102 | try: 103 | precision = TP / (TP + FP1 + FP2) 104 | except: 105 | precision = 0 106 | try: 107 | recall = TP / (TP + FN) 108 | except: 109 | recall = 0 110 | 111 | return (accuracy, precision, recall) 112 | 113 | def tensorboard_log(log_writer, type, avg_loss, TP, TN, FP1, FP2, FN, epoch): 114 | log_writer.add_scalar('{}/loss'.format(type), avg_loss, epoch) 115 | log_writer.add_scalar('{}/TP'.format(type), TP, epoch) 116 | log_writer.add_scalar('{}/TN'.format(type), TN, epoch) 117 | log_writer.add_scalar('{}/FP1'.format(type), FP1, epoch) 118 | log_writer.add_scalar('{}/FP2'.format(type), FP2, epoch) 119 | log_writer.add_scalar('{}/FN'.format(type), FN, epoch) 120 | log_writer.add_scalar('{}/TP'.format(type), TP, epoch) 121 | 122 | (accuracy, precision, recall) = evaluation(TP, TN, FP1, FP2, FN) 123 | 124 | log_writer.add_scalar('{}/Accuracy'.format(type), accuracy, epoch) 125 | log_writer.add_scalar('{}/precision'.format(type), precision, epoch) 126 | log_writer.add_scalar('{}/precision'.format(type), precision, epoch) -------------------------------------------------------------------------------- /val.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | 4 | import os 5 | import sys 6 | import cv2 7 | import numpy as np 8 | from pathlib import Path 9 | from tqdm import tqdm 10 | from argparse import ArgumentParser 11 | 12 | 
from models.tracknet import TrackNet 13 | from utils.dataloaders import create_dataloader 14 | from utils.general import check_dataset, outcome, evaluation 15 | 16 | 17 | # from yolov5 detect.py 18 | FILE = Path(__file__).resolve() 19 | ABS_ROOT = FILE.parents[0] # YOLOv5 root directory 20 | if str(ABS_ROOT) not in sys.path: 21 | sys.path.append(str(ABS_ROOT)) # add ROOT to PATH 22 | ROOT = Path(os.path.relpath(ABS_ROOT, Path.cwd())) # relative 23 | 24 | 25 | 26 | 27 | def wbce_loss(y_true, y_pred): 28 | return -1*( 29 | ((1-y_pred)**2) * y_true * torch.log(torch.clamp(y_pred, min=1e-07, max=1)) + 30 | (y_pred**2) * (1-y_true) * torch.log(torch.clamp(1-y_pred, min=1e-07, max=1)) 31 | ).sum() 32 | 33 | 34 | def validation_loop(device, model, val_loader, save_dir): 35 | model.eval() 36 | 37 | loss_sum = 0 38 | TP = TN = FP1 = FP2 = FN = 0 39 | 40 | with torch.inference_mode(): 41 | pbar = tqdm(val_loader, ncols=180) 42 | for batch_index, (X, y) in enumerate(pbar): 43 | X, y = X.to(device), y.to(device) 44 | y_pred = model(X) 45 | 46 | loss_sum += wbce_loss(y, y_pred).item() 47 | 48 | y_ = y.detach().cpu().numpy() 49 | y_pred_ = y_pred.detach().cpu().numpy() 50 | 51 | y_pred_ = (y_pred_ > 0.5).astype('float32') 52 | (tp, tn, fp1, fp2, fn) = outcome(y_pred_, y_) 53 | TP += tp 54 | TN += tn 55 | FP1 += fp1 56 | FP2 += fp2 57 | FN += fn 58 | 59 | (accuracy, precision, recall) = evaluation(TP, TN, FP1, FP2, FN) 60 | 61 | pbar.set_description('Val loss: {:.6f} | TP: {}, TN: {}, FP1: {}, FP2: {}, FN: {} | Accuracy: {:.4f}, Precision: {:.4f}, Recall: {:.4f}'.format( \ 62 | loss_sum / ((batch_index+1)*X.shape[0]), TP, TN, FP1, FP2, FN, accuracy, precision, recall)) 63 | 64 | F1 = 2 * (precision*recall) / (precision+recall) 65 | print("F1-score: {}".format(F1)) 66 | 67 | return loss_sum/len(val_loader) 68 | 69 | 70 | 71 | def parse_opt(): 72 | parser = ArgumentParser() 73 | 74 | parser.add_argument('--data', type=str, default=ROOT / 'data/match/test.yaml', help='Path to 
dataset.') 75 | parser.add_argument('--weights', type=str, default=ROOT / 'best.pt', help='Path to trained model weights.') 76 | parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[288, 512], help='image size h,w') 77 | parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs, -1 for autobatch') 78 | parser.add_argument('--project', default=ROOT / 'runs/val', help='save results to project/name') 79 | 80 | opt = parser.parse_args() 81 | return opt 82 | 83 | 84 | def main(opt): 85 | d_save_dir = str(opt.project) 86 | f_weights = str(opt.weights) 87 | batch_size = opt.batch_size 88 | f_data = str(opt.data) 89 | imgsz = opt.imgsz 90 | 91 | 92 | data_dict = check_dataset(f_data) 93 | train_path, val_path = data_dict['train'], data_dict['val'] 94 | 95 | if not os.path.exists(d_save_dir): 96 | os.makedirs(d_save_dir) 97 | 98 | 99 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 100 | model = TrackNet().to(device) 101 | 102 | assert os.path.exists(f_weights), f_weights+" is invalid" 103 | print("load pretrain weights {}".format(f_weights)) 104 | model.load_state_dict(torch.load(f_weights)) 105 | 106 | 107 | val_loader = create_dataloader(val_path, imgsz, batch_size=batch_size) 108 | 109 | validation_loop(device, model, val_loader, d_save_dir) 110 | 111 | 112 | if __name__ == '__main__': 113 | opt = parse_opt() 114 | main(opt) --------------------------------------------------------------------------------