├── models
│   ├── pose_estimator
│   │   ├── cfg
│   │   │   ├── readme.txt
│   │   │   └── w32_256x256_adam_lr1e-3.yaml
│   │   ├── readme.txt
│   │   ├── model_weights
│   │   │   └── readme.txt
│   │   └── pose_estimator_model_setup.py
│   └── detectron2
│       ├── readme.txt
│       ├── model_weights
│       │   └── readme.txt
│       └── detectors.py
├── rule_based_programs
│   ├── microprograms
│   │   ├── readme.txt
│   │   ├── temporal_segmentation_functions.py
│   │   ├── dive_error_functions.py
│   │   └── dive_recognition_functions.py
│   ├── README.md
│   ├── aqa_metaProgram_finediving.py
│   ├── aqa_metaProgram.py
│   └── scoring_functions.py
├── score_report_generation
│   ├── templates
│   │   ├── readme.txt
│   │   └── report_template_tables.html
│   └── generate_report_functions.py
├── teaser_fig.png
├── README.md
├── nsaqa.py
└── nsaqa_finediving.py
/models/pose_estimator/cfg/readme.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/rule_based_programs/microprograms/readme.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/score_report_generation/templates/readme.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/teaser_fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/laurenok24/NSAQA/HEAD/teaser_fig.png
--------------------------------------------------------------------------------
/models/detectron2/readme.txt:
--------------------------------------------------------------------------------
1 | You will need to install the detectron2 library: https://github.com/facebookresearch/detectron2
2 |
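3 | For example, one possible setup (detectors.py adds ./detectron2 to sys.path, so cloning into the repository root works):
4 | git clone https://github.com/facebookresearch/detectron2.git
5 | python -m pip install -e detectron2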
--------------------------------------------------------------------------------
/models/pose_estimator/readme.txt:
--------------------------------------------------------------------------------
1 | You will need to install the HRNet model: https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
2 |
3 | Additionally, you will need to add pose_hrnet.py (found at https://github.com/leoxiaobin/deep-high-resolution-net.pytorch/blob/master/lib/models/pose_hrnet.py) to this folder.
4 |
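5 | For example, one possible setup (pose_estimator_model_setup.py adds ./deep-high-resolution-net.pytorch/lib to sys.path, so clone into the repository root):
6 | git clone https://github.com/leoxiaobin/deep-high-resolution-net.pytorch.git
7 | cp deep-high-resolution-net.pytorch/lib/models/pose_hrnet.py models/pose_estimator/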
--------------------------------------------------------------------------------
/models/pose_estimator/model_weights/readme.txt:
--------------------------------------------------------------------------------
1 | Download model weights for HRNet: https://drive.google.com/drive/folders/1sdiGj9lnhNi0-Nix-wSSQeMgYHPMSFCQ?usp=sharing
2 |
3 | If you wish to train the models on your own, here are the annotations we hand annotated for training: https://drive.google.com/drive/folders/1_lKxYwdvWRqxt3YkV4sABbbY7fWhaQSA?usp=sharing
4 |
--------------------------------------------------------------------------------
/models/detectron2/model_weights/readme.txt:
--------------------------------------------------------------------------------
1 | Download model weights for platform, splash, and diver detection:
2 | https://drive.google.com/drive/folders/1sdiGj9lnhNi0-Nix-wSSQeMgYHPMSFCQ?usp=sharing
3 |
4 | If you wish to train the models on your own, here are the annotations we hand annotated for training: https://drive.google.com/drive/folders/1_lKxYwdvWRqxt3YkV4sABbbY7fWhaQSA?usp=sharing
5 |
--------------------------------------------------------------------------------
/rule_based_programs/README.md:
--------------------------------------------------------------------------------
1 | Run all code from the repository root.
2 |
3 | aqa_metaProgram.py:
4 | Extracts the information needed for scoring from a dive clip. The extracted data is saved as a pickle file.
5 | ```
6 | python rule_based_programs/aqa_metaProgram.py path/to/video.mp4
7 | ```
8 |
9 | aqa_metaProgram_finediving.py:
10 | Extracts the same information from the frames of a dive in the FineDiving dataset: https://github.com/xujinglin/FineDiving/tree/main.
11 | Each dive in this dataset has an instance id (x, y).
12 | ```
13 | python rule_based_programs/aqa_metaProgram_finediving.py x y
14 |
15 | # e.g. for instance id ('01', 1)
16 | python rule_based_programs/aqa_metaProgram_finediving.py 01 1
17 | ```
18 |
19 | scoring_functions.py:
20 | You'll need to download [distribution_data.pkl](https://drive.google.com/file/d/1Dbc7LVeuo3d8RVtx0fWdjgoqjbNCtnuk/view?usp=sharing) in order to generate scores. This file contains the distributions (computed from the dives in the FineDiving dataset) used to calculate the percentile scores.
21 |
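22 | As a minimal sketch (the pickle path is an example), the extracted data can be scored with `get_all_report_scores`, the same entry point nsaqa.py uses:
23 | ```
24 | import pickle
25 | from rule_based_programs.scoring_functions import get_all_report_scores
26 |
27 | # dive data previously extracted by aqa_metaProgram.py
28 | with open('path/to/extracted_dive.pkl', 'rb') as f:
29 |     dive_data = pickle.load(f)
30 |
31 | intermediate_scores = get_all_report_scores(dive_data)
32 | ```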
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ⚠️ ***Please note that all materials (code, dataset, etc.) are available for non-commercial use only.*** ⚠️ Please refrain from using the materials for any commercial purposes in any form or capacity.
2 |
3 | # Neuro-Symbolic Action Quality Assessment (NS-AQA)
4 | ## 🏆 CVPR 2024 CVSports Best Paper Award
5 | This repository contains the Python code implementation of NS-AQA for platform diving.
6 |
7 | 📝 **Technical Paper:** [link](https://arxiv.org/abs/2403.13798)
8 |
9 | 🤗 **Huggingface Demo:** [link](https://huggingface.co/spaces/X-NS/NSAQA)
10 |
11 | ## Overview
12 | We propose a neuro-symbolic paradigm for AQA.
13 | 
14 |
15 | ## Run NS-AQA
16 | The score report is saved as an HTML file in "./output/".
17 |
18 | Run NS-AQA on a single dive clip.
19 | ```
20 | python nsaqa.py path/to/video.mp4
21 | ```
22 | Run NS-AQA on a single dive from the [FineDiving Dataset](https://github.com/xujinglin/FineDiving). Each dive in the dataset has an instance id (x, y).
23 | ```
24 | python nsaqa_finediving.py x y
25 |
26 | # e.g. if the instance id is ('01', 1)
27 | python nsaqa_finediving.py 01 1
28 | ```
29 |
30 | ## Please consider citing our work:
31 | ```
32 | @inproceedings{okamoto2024hierarchical,
33 | title={Hierarchical NeuroSymbolic Approach for Comprehensive and Explainable Action Quality Assessment},
34 | author={Okamoto, Lauren and Parmar, Paritosh},
35 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
36 | pages={3204--3213},
37 | year={2024}
38 | }
39 | ```
40 |
--------------------------------------------------------------------------------
/nsaqa.py:
--------------------------------------------------------------------------------
1 | """
2 | nsaqa.py
3 | Author: Lauren Okamoto
4 | """
5 |
6 | import pickle
7 | from models.detectron2.detectors import get_platform_detector, get_diver_detector, get_splash_detector
8 | from models.pose_estimator.pose_estimator_model_setup import get_pose_estimation, get_pose_model
9 | from rule_based_programs.scoring_functions import *
10 | from score_report_generation.generate_report_functions import *
11 | from rule_based_programs.aqa_metaProgram import aqa_metaprogram, abstractSymbols, extract_frames
12 | import argparse
13 |
14 | def main(video_path):
15 | platform_detector = get_platform_detector()
16 | splash_detector = get_splash_detector()
17 | diver_detector = get_diver_detector()
18 | pose_model = get_pose_model()
19 | template_path = 'report_template_tables.html'
20 | dive_data = {}
21 |
22 | frames = extract_frames(video_path)
23 | dive_data = abstractSymbols(frames, platform_detector=platform_detector, splash_detector=splash_detector, diver_detector=diver_detector, pose_model=pose_model)
24 | dive_data = aqa_metaprogram(frames, dive_data, platform_detector=platform_detector, splash_detector=splash_detector, diver_detector=diver_detector, pose_model=pose_model)
25 | intermediate_scores = get_all_report_scores(dive_data)
26 | html = generate_report_from_frames(template_path, intermediate_scores, frames)
27 | save_path = "./output/{}_report.html".format("".join(video_path.split('.')[:-1]))  # assumes the ./output/ directory (and any subfolders from the video path) already exists
28 | with open(save_path, 'w') as f:
29 | print("saving html report into " + save_path)
30 | f.write(html)
31 |
32 |
33 | if __name__ == '__main__':
34 | # Set up command-line arguments
35 | new_parser = argparse.ArgumentParser(description="Extract dive data to be used for scoring.")
36 | new_parser.add_argument("video_path", type=str, help="Path to dive video (mp4 format).")
37 | meta_program_args = new_parser.parse_args()
38 | video_path = meta_program_args.video_path
39 |
40 | main(video_path)
41 |
--------------------------------------------------------------------------------
/nsaqa_finediving.py:
--------------------------------------------------------------------------------
1 | """
2 | nsaqa_finediving.py
3 | Author: Lauren Okamoto
4 | """
5 |
6 | import pickle
7 | from models.detectron2.detectors import get_platform_detector, get_diver_detector, get_splash_detector
8 | from models.pose_estimator.pose_estimator_model_setup import get_pose_estimation, get_pose_model
9 | from rule_based_programs.scoring_functions import *
10 | from score_report_generation.generate_report_functions import *
11 | from rule_based_programs.aqa_metaProgram_finediving import aqa_metaprogram_finediving
12 | import argparse
13 |
14 | def main(key):
15 | platform_detector = get_platform_detector()
16 | splash_detector = get_splash_detector()
17 | diver_detector = get_diver_detector()
18 | pose_model = get_pose_model()
19 |
20 | # Fine-grained annotations from FineDiving Dataset
21 | with open('FineDiving/Annotations/fine-grained_annotation_aqa.pkl', 'rb') as f:
22 | dive_annotation_data = pickle.load(f)
23 | diveNum = dive_annotation_data[key][0]
24 | template_path = 'report_template_tables.html'
25 | dive_data = {}
26 |
27 | dive_data = aqa_metaprogram_finediving(key[0], key[1], diveNum, platform_detector=platform_detector, splash_detector=splash_detector, diver_detector=diver_detector, pose_model=pose_model)
28 | intermediate_scores = get_all_report_scores(dive_data)
29 |
30 | local_directory = "FineDiving/datasets/FINADiving_MTL_256s/{}/{}/".format(key[0], key[1])
31 | html = generate_report(template_path, intermediate_scores, local_directory)
32 | save_path = "./output/{}_{}_report.html".format(key[0], key[1])
33 | with open(save_path, 'w') as f:
34 | print("saving html report into " + save_path)
35 | f.write(html)
36 |
37 |
38 | if __name__ == '__main__':
39 | # Set up command-line arguments
40 | new_parser = argparse.ArgumentParser(description="Extract dive data to be used for scoring.")
41 | new_parser.add_argument("FineDiving_key", type=str, nargs=2, help="key from FineDiving Dataset (e.g. 01 1)")
42 | meta_program_args = new_parser.parse_args()
43 | key = tuple(meta_program_args.FineDiving_key)
44 | key = (key[0], int(key[1]))
45 | print(key)
46 |
47 | main(key)
48 |
--------------------------------------------------------------------------------
/models/detectron2/detectors.py:
--------------------------------------------------------------------------------
1 | """
2 | detectors.py
3 | Author: Lauren Okamoto
4 |
5 | Code used to initialize the object detector models to be used for inference.
6 | """
7 |
8 | import sys, os, distutils.core
9 | sys.path.insert(0, os.path.abspath('./detectron2'))
10 |
11 | import detectron2
12 | import cv2
13 |
14 | from detectron2.utils.logger import setup_logger
15 | setup_logger()
16 |
17 | from detectron2 import model_zoo
18 | from detectron2.engine import DefaultPredictor
19 | from detectron2.config import get_cfg
20 | from detectron2.utils.visualizer import Visualizer
21 | from detectron2.data import MetadataCatalog, DatasetCatalog
22 | from detectron2.checkpoint import DetectionCheckpointer
23 | from detectron2.data.datasets import register_coco_instances
24 |
25 |
26 | def get_diver_detector():
27 | cfg = get_cfg()
28 | cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
29 | cfg.OUTPUT_DIR = "./models/detectron2/model_weights/"
30 | cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "diver_model_final.pth")
31 | cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7
32 | cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
33 | diver_detector = DefaultPredictor(cfg)
34 | return diver_detector
35 |
36 |
37 | def get_platform_detector():
38 | cfg = get_cfg()
39 | cfg.OUTPUT_DIR = "./models/detectron2/model_weights/"
40 | cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
41 | cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2
42 | cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "plat_model_final.pth")
43 | cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7
44 | platform_detector = DefaultPredictor(cfg)
45 | return platform_detector
46 |
47 | def get_splash_detector():
48 | cfg = get_cfg()
49 | cfg.OUTPUT_DIR = "./models/detectron2/model_weights/"
50 | cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
51 | cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2
52 | cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "splash_model_final.pth")
53 | cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7
54 | splash_detector = DefaultPredictor(cfg)
55 | return splash_detector
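56 |
57 | # Example usage (a minimal sketch; the frame path is hypothetical):
58 | #   detector = get_splash_detector()
59 | #   outputs = detector(cv2.imread("frames/000050.jpg"))
60 | #   print(outputs["instances"].pred_boxes, outputs["instances"].scores)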
--------------------------------------------------------------------------------
/models/pose_estimator/cfg/w32_256x256_adam_lr1e-3.yaml:
--------------------------------------------------------------------------------
1 | AUTO_RESUME: true
2 | CUDNN:
3 | BENCHMARK: true
4 | DETERMINISTIC: false
5 | ENABLED: true
6 | DATA_DIR: ''
7 | GPUS: (0,)
8 | OUTPUT_DIR: 'output'
9 | LOG_DIR: 'log'
10 | WORKERS: 8
11 | PRINT_FREQ: 100
12 |
13 | DATASET:
14 | COLOR_RGB: true
15 | DATASET: mpii
16 | DATA_FORMAT: jpg
17 | FLIP: true
18 | NUM_JOINTS_HALF_BODY: 8
19 | PROB_HALF_BODY: -1.0
20 | ROOT: 'deep-high-resolution-net.pytorch/data/mpii/'
21 | ROT_FACTOR: 30
22 | SCALE_FACTOR: 0.25
23 | TEST_SET: valid
24 | TRAIN_SET: train_all
25 | MODEL:
26 | INIT_WEIGHTS: true
27 | NAME: pose_hrnet
28 | NUM_JOINTS: 16
29 | PRETRAINED: 'deep-high-resolution-net.pytorch/models/pytorch/pose_mpii/pose_hrnet_w32_256x256.pth'
30 | TARGET_TYPE: gaussian
31 | IMAGE_SIZE:
32 | - 256
33 | - 256
34 | HEATMAP_SIZE:
35 | - 64
36 | - 64
37 | SIGMA: 2
38 | EXTRA:
39 | PRETRAINED_LAYERS:
40 | - 'conv1'
41 | - 'bn1'
42 | - 'conv2'
43 | - 'bn2'
44 | - 'layer1'
45 | - 'transition1'
46 | - 'stage2'
47 | - 'transition2'
48 | - 'stage3'
49 | - 'transition3'
50 | - 'stage4'
51 | FINAL_CONV_KERNEL: 1
52 | STAGE2:
53 | NUM_MODULES: 1
54 | NUM_BRANCHES: 2
55 | BLOCK: BASIC
56 | NUM_BLOCKS:
57 | - 4
58 | - 4
59 | NUM_CHANNELS:
60 | - 32
61 | - 64
62 | FUSE_METHOD: SUM
63 | STAGE3:
64 | NUM_MODULES: 4
65 | NUM_BRANCHES: 3
66 | BLOCK: BASIC
67 | NUM_BLOCKS:
68 | - 4
69 | - 4
70 | - 4
71 | NUM_CHANNELS:
72 | - 32
73 | - 64
74 | - 128
75 | FUSE_METHOD: SUM
76 | STAGE4:
77 | NUM_MODULES: 3
78 | NUM_BRANCHES: 4
79 | BLOCK: BASIC
80 | NUM_BLOCKS:
81 | - 4
82 | - 4
83 | - 4
84 | - 4
85 | NUM_CHANNELS:
86 | - 32
87 | - 64
88 | - 128
89 | - 256
90 | FUSE_METHOD: SUM
91 | LOSS:
92 | USE_TARGET_WEIGHT: true
93 | TRAIN:
94 | BATCH_SIZE_PER_GPU: 64
95 | SHUFFLE: true
96 | BEGIN_EPOCH: 0
97 | END_EPOCH: 210
98 | OPTIMIZER: adam
99 | LR: 0.001
100 | LR_FACTOR: 0.1
101 | LR_STEP:
102 | - 170
103 | - 200
104 | WD: 0.0001
105 | GAMMA1: 0.99
106 | GAMMA2: 0.0
107 | MOMENTUM: 0.9
108 | NESTEROV: false
109 | TEST:
110 | BATCH_SIZE_PER_GPU: 64
111 | MODEL_FILE: './models/pose_estimator/model_weights/diver_pose_model_best.pth'
112 | FLIP_TEST: true
113 | POST_PROCESS: true
114 | SHIFT_HEATMAP: true
115 | DEBUG:
116 | DEBUG: true
117 | SAVE_BATCH_IMAGES_GT: true
118 | SAVE_BATCH_IMAGES_PRED: true
119 | SAVE_HEATMAPS_GT: true
120 | SAVE_HEATMAPS_PRED: true
121 |
--------------------------------------------------------------------------------
/rule_based_programs/microprograms/temporal_segmentation_functions.py:
--------------------------------------------------------------------------------
1 | """
2 | temporal_segmentation_functions.py
3 | Author: Lauren Okamoto
4 |
5 | Temporal Segmentation Microprograms for detecting start/takeoff, somersault, twist, and entry phases
6 | """
7 |
8 | from rule_based_programs.microprograms.dive_error_functions import get_splash_from_one_frame, applyPositionTightnessError
9 | import numpy as np
10 |
11 | """
12 | Start/Takeoff Phase Microprogram
13 |
14 | Parameters:
15 | - filepath (str): file path where the frame is located
16 | - above_board (bool): True if the diver is above the board at this frame
17 | - on_board (bool): True if the diver is on the board at this frame
18 | - pose_pred: pose estimation of the diver at this frame (None if no diver detected)
19 |
20 | Returns:
21 | - 0 if frame is not in start/takeoff phase
22 | - 1 if frame is in start/takeoff phase
23 | """
24 | def takeoff_microprogram_one_frame(filepath, above_board, on_board, pose_pred=None):
25 | if not above_board:
26 | return 0
27 | if on_board:
28 | return 1
29 | return 0
30 |
31 | """
32 | Somersault Phase Microprogram
33 |
34 | Parameters:
35 | - filepath (str): file path where the frame is located
36 | - on_board (bool): True if the diver is on the board at this frame
37 | - expected_som (int): number of somersaults in full dive (from action recognition)
38 | - half_som_count (int): number of somersaults counted by this frame
39 | - expected_twists (int): number of twists in full dive (from action recognition)
40 | - petal_count (int): number of twists counted by this frame
41 | - pose_pred: pose estimation of the diver at this frame (None if no diver detected)
42 | - diver_detector: diver detector model
43 | - pose_model: pose estimator model
44 |
45 | Returns:
46 | - 0 if frame is not in somersault phase
47 | - 1 if frame is in somersault phase
48 | """
49 | def somersault_microprogram_one_frame(filepath, on_board, expected_som, half_som_count, expected_twists, petal_count, pose_pred=None, diver_detector=None, pose_model=None):
50 | if on_board:
51 | return 0
52 | if expected_som <= half_som_count:
53 | return 0
54 | # if not done with som or twists, need to determine if som or twist
55 | angle = applyPositionTightnessError(filepath, pose_pred, diver_detector=diver_detector, pose_model=pose_model)
56 | if angle is None:
57 | return 0
58 | # if not done with som but done with twists
59 | if expected_som > half_som_count and expected_twists <= petal_count:
60 | return 1
61 | # print("angle:", angle)
62 | if angle <= 80:
63 | return 1
64 | else:
65 | return 0
66 |
67 | """
68 | Twist Phase Microprogram
69 |
70 | Parameters:
71 | - filepath (str): file path where the frame is located
72 | - on_board (bool): True if the diver is on the board at this frame
73 | - expected_twists (int): number of twists in full dive (from action recognition)
74 | - petal_count (int): number of twists counted by this frame
75 | - expected_som (int): number of somersaults in full dive (from action recognition)
76 | - half_som_count (int): number of somersaults counted by this frame
77 | - pose_pred: pose estimation of the diver at this frame (None if no diver detected)
78 | - diver_detector: diver detector model
79 | - pose_model: pose estimator model
80 |
81 | Returns:
82 | - 0 if frame is not in twist phase
83 | - 1 if frame is in twist phase
84 | """
85 | def twist_microprogram_one_frame(filepath, on_board, expected_twists, petal_count, expected_som, half_som_count, pose_pred=None, diver_detector=None, pose_model=None):
86 | if on_board:
87 | return 0
88 | if expected_twists <= petal_count or expected_som <= half_som_count:
89 | return 0
90 | angle = applyPositionTightnessError(filepath, pose_pred=pose_pred, diver_detector=diver_detector, pose_model=pose_model)
91 | if angle is None:
92 | return 0
93 | if angle > 80:
94 | return 1
95 | else:
96 | return 0
97 |
98 | """
99 | Entry Phase Microprogram
100 |
101 | Parameters:
102 | - filepath (str): file path where the frame is located
103 | - above_board (bool): True if the diver is above the board at this frame
104 | - on_board (bool): True if the diver is on the board at this frame
105 | - pose_pred: pose estimation of the diver at this frame (None if no diver detected)
106 | - expected_twists (int): number of twists in full dive (from action recognition)
107 | - petal_count (int): number of twists counted by this frame
108 | - expected_som (int): number of somersaults in full dive (from action recognition)
109 | - half_som_count (int): number of somersaults counted by this frame
110 | - frame: image of the frame (BGR array, as returned by cv2.imread)
111 | - splash_detector: splash detector model
112 | - visualize: True if you want to save the splash segmentation mask prediction to an image
113 | - dive_folder_num: if visualize is true, this is where the image will be saved
114 |
115 | Returns:
116 | - 0 if frame is not in entry phase
117 | - 1 if frame is in entry phase
118 | """
119 | def entry_microprogram_one_frame(filepath, above_board, on_board, pose_pred, expected_twists, petal_count, expected_som, half_som_count, frame=None, splash_detector=None, visualize=False, dive_folder_num=None):
120 | if above_board:
121 | return 0
122 | if on_board:
123 | return 0
124 | splash, _ = get_splash_from_one_frame(filepath, im=frame, predictor=splash_detector, visualize=visualize, dive_folder_num=dive_folder_num)  # unpack the (area, mask) return value; the tuple itself is always truthy
125 | if splash:
126 | return 1
127 | # if completed with somersaults, we know we're in entry phase
128 | if not expected_som > half_som_count:
129 | return 1
130 | if expected_twists > petal_count or expected_som > half_som_count:
131 | return 0
132 | return 1
133 |
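134 | # Example usage (a minimal sketch; the frame path and phase inputs are hypothetical):
135 | #   in_takeoff = takeoff_microprogram_one_frame("frames/000001.jpg", above_board=True, on_board=True)
136 | #   in_entry = entry_microprogram_one_frame("frames/000090.jpg", above_board=False, on_board=False,
137 | #                                           pose_pred=None, expected_twists=0, petal_count=0,
138 | #                                           expected_som=5, half_som_count=5)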
--------------------------------------------------------------------------------
/models/pose_estimator/pose_estimator_model_setup.py:
--------------------------------------------------------------------------------
1 | """
2 | pose_estimator_model_setup.py
3 | Author: Lauren Okamoto
4 |
5 | Code used to initialize the pose estimator model, HRNet, combined with the diver detectron2 model to be used for diver pose inference.
6 | """
7 |
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 |
12 | import argparse
13 | import csv
14 | import os
15 | import shutil
16 | import sys
17 |
18 | from PIL import Image
19 | import torch
20 | import torch.nn.parallel
21 | import torch.backends.cudnn as cudnn
22 | import torch.optim
23 | import torch.utils.data
24 | import torch.utils.data.distributed
25 | import torchvision.transforms as transforms
26 | import torchvision
27 | import cv2
28 | import numpy as np
29 | import time
30 | sys.path.append('./deep-high-resolution-net.pytorch/lib')
31 | import models
32 | from config import cfg
33 | from config import update_config
34 | from core.function import get_final_preds
35 | from utils.transforms import get_affine_transform
36 |
37 | import distutils.core
38 |
39 | from models.detectron2.detectors import get_diver_detector
40 | from models.pose_estimator.pose_hrnet import get_pose_net
41 |
42 |
43 | def box_to_center_scale(box, model_image_width, model_image_height):
44 | """convert a box to center,scale information required for pose transformation
45 | Parameters
46 | ----------
47 | box : tensor of four floats
48 | (x1, y1, x2, y2) coordinates of two opposite
49 | corners of the detection box
50 | model_image_width : int
51 | model_image_height : int
52 |
53 | Returns
54 | -------
55 | (numpy array, numpy array)
56 | Two numpy arrays, coordinates for the center of the box and the scale of the box
57 | """
58 | center = np.zeros((2), dtype=np.float32)
59 | bottom_left_corner = (box[0].data.cpu().item(), box[1].data.cpu().item())
60 | top_right_corner = (box[2].data.cpu().item(), box[3].data.cpu().item())
61 | box_width = top_right_corner[0]-bottom_left_corner[0]
62 | box_height = top_right_corner[1]-bottom_left_corner[1]
63 | bottom_left_x = bottom_left_corner[0]
64 | bottom_left_y = bottom_left_corner[1]
65 | center[0] = bottom_left_x + box_width * 0.5
66 | center[1] = bottom_left_y + box_height * 0.5
67 | aspect_ratio = model_image_width * 1.0 / model_image_height
68 | pixel_std = 200
69 | if box_width > aspect_ratio * box_height:
70 | box_height = box_width * 1.0 / aspect_ratio
71 | elif box_width < aspect_ratio * box_height:
72 | box_width = box_height * aspect_ratio
73 | scale = np.array(
74 | [box_width * 1.0 / pixel_std, box_height * 1.0 / pixel_std],
75 | dtype=np.float32)
76 | if center[0] != -1:
77 | scale = scale * 1.25
78 | return center, scale
79 |
80 | def parse_args():
81 | parser = argparse.ArgumentParser(description='Train keypoints network')
82 | parser.add_argument('--cfg', type=str, default='./models/pose_estimator/cfg/w32_256x256_adam_lr1e-3.yaml')
83 | parser.add_argument('opts',
84 | help='Modify config options using the command-line',
85 | default=None,
86 | nargs=argparse.REMAINDER)
87 | args = parser.parse_args()
88 | args.opts = ''
89 | args.modelDir = ''
90 | args.logDir = ''
91 | args.dataDir = ''
92 | args.prevModelDir = ''
93 | return args
94 |
95 | def get_pose_estimation_prediction(pose_model, image, center, scale):
96 | rotation = 0
97 | trans = get_affine_transform(center, scale, rotation, cfg.MODEL.IMAGE_SIZE)
98 | transform = transforms.Compose([
99 | transforms.ToTensor(),
100 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
101 | std=[0.229, 0.224, 0.225]),
102 | ])
103 | model_input = cv2.warpAffine(
104 | image,
105 | trans,
106 | (256, 256),
107 | flags=cv2.INTER_LINEAR)
108 |
109 | # pose estimation inference
110 | model_input = transform(model_input).unsqueeze(0)
111 | # switch to evaluate mode
112 | pose_model.eval()
113 | with torch.no_grad():
114 | # compute output heatmap
115 | output = pose_model(model_input)
116 | preds, _ = get_final_preds(
117 | cfg,
118 | output.clone().cpu().numpy(),
119 | np.asarray([center]),
120 | np.asarray([scale]))
121 | return preds
122 |
123 | def get_pose_model():
124 | CTX = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
125 | cudnn.benchmark = cfg.CUDNN.BENCHMARK
126 | torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
127 | torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED
128 | args = parse_args()
129 | update_config(cfg, args)
130 | pose_model = get_pose_net(cfg, is_train=False)
131 | if cfg.TEST.MODEL_FILE:
132 | print('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
133 | pose_model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=False)
134 | else:
135 | print('expected model defined in config at TEST.MODEL_FILE')
136 | pose_model = torch.nn.DataParallel(pose_model, device_ids=cfg.GPUS)
137 | pose_model.to(CTX)
138 | pose_model.eval()
139 | return pose_model
140 |
141 | def get_pose_estimation(filepath, image_bgr=None, diver_detector=None, pose_model=None):
142 | if image_bgr is None:
143 | image_bgr = cv2.imread(filepath)
144 | if image_bgr is None:
145 | print("ERROR: image {} does not exist".format(filepath))
146 | return None
147 | if diver_detector is None:
148 | diver_detector = get_diver_detector()
149 |
150 | if pose_model is None:
151 | pose_model = get_pose_model()
152 |
153 | image = image_bgr[:, :, [2, 1, 0]]
154 |
155 | outputs = diver_detector(image_bgr)
156 | scores = outputs['instances'].scores
157 | pred_boxes = []
158 | if len(scores) > 0:
159 | pred_boxes = outputs['instances'].pred_boxes
160 |
161 | if len(pred_boxes) >= 1:
162 | for box in pred_boxes:
163 | center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
164 | image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy()
165 | box = box.detach().cpu().numpy()
166 | return box, get_pose_estimation_prediction(pose_model, image_pose, center, scale)
167 | return None, None
168 |
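169 | # Example usage (a minimal sketch; the frame path is hypothetical):
170 | #   pose_model = get_pose_model()
171 | #   box, pose_pred = get_pose_estimation("frames/000042.jpg", pose_model=pose_model)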
--------------------------------------------------------------------------------
/score_report_generation/templates/report_template_tables.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
9 |
10 |
11 |
12 | | Your dive was {{overall_score_desc}}, and scored a {{overall_score}}. Here is how we scored each component of the dive. Each percentile is relative to dives in the semifinals and finals of Olympic and World Championship competitions. |
13 |
14 |
15 |
16 |
17 | | Error |
18 | Description |
19 | Visuals |
20 | Score |
21 |
22 |
23 | | Feet Apart |
24 | We found that your leg separation angle was on average {{feet_apart_score}}° for your dive.
25 | This is rated as {{feet_apart_percentile}} percentile. |
26 |
27 | {% if has_feet_apart_peaks %}
28 |
29 |
30 | {%else%}
31 | There were no particular instances to show where your feet came apart.
32 | {%endif%}
33 | |
34 |
35 | {{feet_apart_percentile_divided_by_ten}}
36 | |
37 |
38 | {% if include_height_off_platform %}
39 |
40 | | Height off platform |
41 | Your jump was {{height_off_board_description}}, and was rated as {{height_off_board_percentile}} percentile. Here is the highest you jumped off the platform. |
42 |
43 |
44 | |
45 |
46 | {{height_off_board_percentile_divided_by_ten}}
47 | |
48 |
49 | {% endif %}
50 |
51 | | Distance from platform |
52 | You were {{dist_from_board_percentile}} the platform. Here is where you came closest to the platform. |
53 |
54 |
55 |
56 | |
57 |
58 | {{dist_from_board_percentile_status}}
59 | |
60 |
61 | {% if som_position_tightness_frames|length > 0 %}
62 |
63 | | Somersault tightness |
64 |
65 | We found that the tightness of your {{som_position_tightness_position}} was {{som_position_tightness_score}}° on average.
66 | This is rated as {{som_position_tightness_percentile}} percentile. Here are some examples of your position in the somersault.
67 | |
68 | {% if knee_bend_frames|length > 0 %}
69 |
70 |
71 |
72 | |
73 | {% else %}
74 |
75 |
76 |
77 | |
78 | {% endif %}
79 |
80 | {{som_position_tightness_percentile_divided_by_ten}}
81 | |
82 |
83 | {% endif %}
84 | {% if knee_bend_frames|length > 0 %}
85 |
86 | | Knee straightness |
87 |
88 | We found that your knees bent {{knee_bend_score}}° on average.
89 | This is rated as {{knee_bend_percentile}} percentile.
90 | |
91 |
92 | {{knee_bend_percentile_divided_by_ten}}
93 | |
94 |
95 | {% endif %}
96 | {% if is_twister %}
97 |
98 | |
99 | Twist Straightness
100 | |
101 |
102 | We found that the tightness of your {{twist_position_tightness_position}} was {{twist_position_tightness_score}}° on average.
103 | This is rated as {{twist_position_tightness_percentile}} percentile. Here are some examples of your position in the twist.
104 | |
105 |
106 |
107 |
108 | |
109 |
110 | {{twist_position_tightness_percentile_divided_by_ten}}
111 | |
112 |
113 | {% endif %}
114 |
115 | |
116 | Verticalness (over/under rotation)
117 | |
118 |
119 | We found that you deviated from vertical by {{over_under_rotation_score}}°, which was the {{over_under_rotation_percentile}} percentile.
120 | |
121 |
122 |
123 |
124 | |
125 |
126 | {{over_under_rotation_percentile_divided_by_ten}}
127 | |
128 |
129 |
130 | | Body straightness during entry |
131 | The straightness of your body during entry deviated by {{straightness_during_entry_score}}°, which was the {{straightness_during_entry_percentile}} percentile. |
132 | {{straightness_during_entry_percentile_divided_by_ten}} |
133 |
134 |
135 | | Splash |
136 | Your splash was {{splash_description}} and rated as {{splash_percentile}} percentile. |
137 |
138 |
139 |
140 | |
141 |
142 | {{splash_percentile_divided_by_ten}}
143 | |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
--------------------------------------------------------------------------------
/rule_based_programs/microprograms/dive_error_functions.py:
--------------------------------------------------------------------------------
1 | """
2 | dive_error_functions.py
3 | Author: Lauren Okamoto
4 |
5 | Dive error microprograms and helper functions used for AQA scoring (feet apart, position tightness, distance from platform, over/under rotation, splash size).
6 | """
7 |
8 | import torch
9 | import numpy as np
10 | import math
11 | import cv2
12 | import sys, os
13 | from matplotlib import image
14 | from matplotlib import pyplot as plt
15 | from math import atan
16 | from models.pose_estimator.pose_estimator_model_setup import get_pose_estimation
17 | from models.detectron2.detectors import get_platform_detector, get_splash_detector
18 | from detectron2.utils.visualizer import Visualizer
19 |
20 | ################## HELPER FUNCTIONS ##################
21 | def slope(x1, y1, x2, y2):
22 | if x1 == x2:
23 | return "undefined"
24 | return (y2-y1)/(x2-x1)
25 |
26 | # Function to find the angle between two lines
27 | def findAngle(M1, M2):
28 | vertical_line = False
29 | if M1 == "undefined":
30 | M1 = 0
31 | vertical_line = True
32 | if M2 == "undefined":
33 | M2 = 0
34 | vertical_line = True
35 | PI = 3.14159265
36 | angle = abs((M2 - M1) / (1 + M1 * M2)) if (1 + M1 * M2) != 0 else float('inf')  # perpendicular lines: atan(inf) gives 90 degrees
37 | ret = atan(angle)
38 | val = (ret * 180) / PI
39 | if vertical_line:
40 | return 90 - round(val,4)
41 | return (round(val, 4))
42 |
43 | def find_which_side_board_on(output):
44 | pred_classes = output['instances'].pred_classes.cpu().numpy()
45 | platforms = np.where(pred_classes == 0)[0]
46 | scores = output['instances'].scores[platforms]
47 | if len(scores) == 0:
48 | return
49 | pred_masks = output['instances'].pred_masks[platforms]
50 | max_instance = torch.argmax(scores)
51 | pred_mask = np.array(pred_masks[max_instance].cpu())
52 | for i in range(len(pred_mask[0])//2):
53 | if sum(pred_mask[:, i]) != 0:
54 | return "left"
55 | elif sum(pred_mask[:, len(pred_mask[0]) - i - 1]) != 0:
56 | return "right"
57 | return None
58 |
59 | def board_end(output, board_side=None):
60 | pred_classes = output['instances'].pred_classes.cpu().numpy()
61 | platforms = np.where(pred_classes == 0)[0]
62 | scores = output['instances'].scores[platforms]
63 | if len(scores) == 0:
64 | return
65 | pred_masks = output['instances'].pred_masks[platforms]
66 | max_instance = torch.argmax(scores)
67 | pred_mask = np.array(pred_masks[max_instance].cpu()) # platform instance with highest confidence
68 | # need to figure out whether springboard is on left or right side of frame, then need to find where the edge is
69 | if board_side is None:
70 | board_side = find_which_side_board_on(output)
71 | if board_side == "left":
72 | for i in range(len(pred_mask[0]) - 1, -1, -1):
73 | if sum(pred_mask[:, i]) != 0:
74 | trues = np.where(pred_mask[:, i])[0]
75 | return (i, min(trues))
76 | if board_side == "right":
77 | for i in range(len(pred_mask[0])):
78 | if sum(pred_mask[:, i]) != 0:
79 | trues = np.where(pred_mask[:, i])[0]
80 | return (i, min(trues))
81 | return None
82 |
83 | # Splash helper function
84 | def get_splash_pred_mask(output):
85 | pred_classes = output['instances'].pred_classes.cpu().numpy()
86 | splashes = np.where(pred_classes == 0)[0]
87 | scores = output['instances'].scores[splashes]
88 | if len(scores) == 0:
89 | return None
90 | pred_masks = output['instances'].pred_masks[splashes]
91 | max_instance = torch.argmax(scores)
92 | pred_mask = np.array(pred_masks[max_instance].cpu())
93 | return pred_mask
94 |
95 | # Splash helper function that finds the splash instance with the highest percent confidence
96 | # and returns the fraction of the image area covered by its mask
97 | def splash_area_percentage(output, pred_mask=None):
98 | if pred_mask is None:
99 | return
100 | # loops over pixels to get sum of splash pixels
101 | totalSum = 0
102 | for j in range(len(pred_mask)):
103 | totalSum += pred_mask[j].sum()
104 | # return percentage of image that is splash
105 | return totalSum/(len(pred_mask) * len(pred_mask[0]))
106 |
107 | def draw_two_coord(im, coord1, coord2, filename):
108 | print("hello, im in the drawing func")
109 | image = cv2.circle(im, (int(coord1[0]),int(coord1[1])), radius=5, color=(0, 0, 255), thickness=-1)
110 | image = cv2.circle(image, (int(coord2[0]),int(coord2[1])), radius=5, color=(0, 255, 0), thickness=-1)
111 | print(filename)
112 | if not cv2.imwrite(filename, image):
113 | print(filename)
114 | print("file failed to write")
115 |
116 | def draw_board_end_coord(im, coord, file_name):  # file_name: frame filename used for the output path ('d' was undefined here)
117 | print("hello, im in the drawing func")
118 | image = cv2.circle(im, (int(coord[0]),int(coord[1])), radius=10, color=(0, 0, 255), thickness=-1)
119 | filename = os.path.join("./output/board_end/", file_name[3:])
120 | print(filename)
121 | if not cv2.imwrite(filename, image):
122 | print(filename)
123 | print("file failed to write")
124 |
125 |
126 |
127 | #######################################################
128 |
129 | ################## DIVE ERROR MICROPROGRAMS ##################
130 |
131 | ## Feet apart error ##
132 | def applyFeetApartError(filepath, pose_pred=None, diver_detector=None, pose_model=None):
133 | if pose_pred is None and filepath != "":
134 | diver_box, pose_pred = get_pose_estimation(filepath, diver_detector=diver_detector, pose_model=pose_model)
135 | if pose_pred is not None:
136 | pose_pred = np.array(pose_pred)[0]
137 | average_knee = [np.mean((pose_pred[4][0], pose_pred[1][0])), np.mean((pose_pred[4][1], pose_pred[1][1]))]  # midpoint of the knees (MPII joints 1 and 4)
138 | vector1 = [pose_pred[5][0] - average_knee[0], pose_pred[5][1] - average_knee[1]]  # knee midpoint -> left ankle (joint 5)
139 | vector2 = [pose_pred[0][0] - average_knee[0], pose_pred[0][1] - average_knee[1]]  # knee midpoint -> right ankle (joint 0)
140 | unit_vector_1 = vector1 / np.linalg.norm(vector1)
141 | unit_vector_2 = vector2 / np.linalg.norm(vector2)
142 | dot_product = np.dot(unit_vector_1, unit_vector_2)
143 | angle = math.degrees(np.arccos(dot_product))
144 | return angle
145 | else:
146 | return None
147 |
148 | ## Calculates hip bend for somersault tightness & twist straightness errors ##
149 | def applyPositionTightnessError(filepath, pose_pred=None, diver_detector=None, pose_model=None):
150 | if pose_pred is None and filepath != "":
151 | diver_box, pose_pred = get_pose_estimation(filepath, diver_detector=diver_detector, pose_model=pose_model)
152 | if pose_pred is not None:
153 | pose_pred = np.array(pose_pred)[0]
154 | vector1 = [pose_pred[7][0] - pose_pred[2][0], pose_pred[7][1] - pose_pred[2][1]]  # right hip (MPII joint 2) -> thorax (joint 7)
155 | vector2 = [pose_pred[1][0] - pose_pred[2][0], pose_pred[1][1] - pose_pred[2][1]]  # right hip (joint 2) -> right knee (joint 1)
156 | unit_vector_1 = vector1 / np.linalg.norm(vector1)
157 | unit_vector_2 = vector2 / np.linalg.norm(vector2)
158 | dot_product = np.dot(unit_vector_1, unit_vector_2)
159 | angle = math.degrees(np.arccos(dot_product))
160 | return angle
161 | else:
162 | return None
163 |
164 | ## Distance from board error ##
165 | def calculate_distance_from_platform_for_one_frame(filepath, im=None, visualize=False, dive_folder_num="", platform_detector=None, pose_pred=None, diver_detector=None, pose_model=None, board_end_coord=None, board_side=None):
166 | if platform_detector is None:
167 | platform_detector = get_platform_detector()
168 | if pose_pred is None:
169 | diver_box, pose_pred = get_pose_estimation(filepath, image_bgr=im, diver_detector=diver_detector, pose_model=pose_model)
170 | if im is None and filepath != "":
171 | im = cv2.imread(filepath)
172 | if board_end_coord is None:
173 | outputs = platform_detector(im)
174 | board_end_coord = board_end(outputs, board_side=board_side)
175 | minDist = None
176 | if board_end_coord is not None and pose_pred is not None and len(board_end_coord) == 2:
177 | minDist = float('inf')
178 | for i in range(len(np.array(pose_pred)[0])):
179 | distance = math.dist(np.array(pose_pred)[0][i], np.array(board_end_coord))
180 | if distance < minDist:
181 | minDist = distance
182 | minJoint = i
183 | if visualize:
184 | file_name = filepath.split('/')[-1]
185 | folder = "./output/data/distance_from_board/{}".format(dive_folder_num)
186 | out_filename = os.path.join(folder, file_name)
187 | if not os.path.exists(folder):
188 | os.makedirs(folder)
189 | draw_two_coord(im, board_end_coord, np.array(pose_pred)[0][minJoint], filename=out_filename)
190 | return minDist
191 |
192 | ## Over-rotation error ##
193 | def over_rotation(filepath, pose_pred=None, diver_detector=None, pose_model=None):
194 | if pose_pred is None and filepath != "":
195 | diver_box, pose_pred = get_pose_estimation(filepath, diver_detector=diver_detector, pose_model=pose_model)
196 | if pose_pred is not None:
197 | pose_pred = np.array(pose_pred)[0]
198 | vector1 = [(pose_pred[0][0] - pose_pred[2][0]), 0-(pose_pred[0][1] - pose_pred[2][1])]  # right hip (MPII joint 2) -> right ankle (joint 0), y flipped to standard orientation
199 | vector2 = [-1, 0]
200 | unit_vector_1 = vector1 / np.linalg.norm(vector1)
201 | unit_vector_2 = vector2 / np.linalg.norm(vector2)
202 | dot_product = np.dot(unit_vector_1, unit_vector_2)
203 | angle = math.degrees(np.arccos(dot_product))
204 | return angle
205 | else:
206 | return None
207 |
208 | ## Splash size error ##
209 | def get_splash_from_one_frame(filepath, im=None, predictor=None, visualize=False, dive_folder_num=""):
210 | if predictor is None:
211 | predictor=get_splash_detector()
212 | if im is None:
213 | im = cv2.imread(filepath)
214 | outputs = predictor(im)
215 | pred_mask = get_splash_pred_mask(outputs)
216 | area = splash_area_percentage(outputs, pred_mask=pred_mask)
217 | if area is None:
218 | return None, None
219 | if visualize:
220 | pred_boxes = outputs['instances'].pred_boxes
221 | print("pred_boxes", pred_boxes)
222 | for box in pred_boxes:
223 | image = cv2.rectangle(im, (int(box[0]),int(box[1])), (int(box[2]),int(box[3])), color=(0, 0, 255), thickness=2)
224 | out_folder= "./output/data/splash/{}".format(dive_folder_num)
225 | if not os.path.exists(out_folder):
226 | os.makedirs(out_folder)
227 | filename = os.path.join(out_folder, filepath.split('/')[-1])
228 | if not cv2.imwrite(filename, image):
229 | print('no image written to', filename)
230 | break
231 |
232 | return area.tolist(), pred_mask
233 |
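234 | # Example usage (a minimal sketch; the frame paths are hypothetical):
235 | #   hip_angle = applyPositionTightnessError("frames/000060.jpg")
236 | #   splash_area, splash_mask = get_splash_from_one_frame("frames/000095.jpg")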
--------------------------------------------------------------------------------
/rule_based_programs/aqa_metaProgram_finediving.py:
--------------------------------------------------------------------------------
1 | """
2 | aqa_metaProgram_finediving.py
3 | Author: Lauren Okamoto
4 | """
5 |
6 | from rule_based_programs.microprograms.dive_error_functions import *
7 | from rule_based_programs.microprograms.temporal_segmentation_functions import *
8 | from rule_based_programs.microprograms.dive_recognition_functions import *
9 | from rule_based_programs.scoring_functions import get_scale_factor
10 | from models.detectron2.detectors import get_platform_detector, get_diver_detector, get_splash_detector
11 | from models.pose_estimator.pose_estimator_model_setup import get_pose_estimation, get_pose_model
12 | import pickle
13 | import os, math
14 | import numpy as np
15 | import cv2
16 | import argparse
17 |
18 | def getDiveInfo_from_diveNum(diveNum):
19 | handstand = (diveNum[0] == '6')
20 | expected_som = int(diveNum[2])
21 | if len(diveNum) == 5:
22 | expected_twists = int(diveNum[3])
23 | else:
24 | expected_twists = 0
25 | if diveNum[0] == '1' or diveNum[0] == '3' or diveNum[:2] == '51' or diveNum[:2] == '53' or diveNum[:2] == '61' or diveNum[:2] == '63':
26 | back_facing = False
27 | else:
28 | back_facing = True
29 | if diveNum[0] == '1' or diveNum[:2] == '51' or diveNum[:2] == '61':
30 | expected_direction = 'front'
31 | elif diveNum[0] == '2' or diveNum[:2] == '52' or diveNum[:2] == '62':
32 | expected_direction = 'back'
33 | elif diveNum[0] == '3' or diveNum[:2] == '53' or diveNum[:2] == '63':
34 | expected_direction = 'reverse'
35 | elif diveNum[0] == '4':
36 | expected_direction = 'inward'
37 | if diveNum[-1] == 'b':
38 | position = 'pike'
39 | elif diveNum[-1] == 'c':
40 | position = 'tuck'
41 | else:
42 | position = 'free'
43 | return handstand, expected_som, expected_twists, back_facing, expected_direction, position
44 |
45 | def aqa_metaprogram_finediving(first_folder, second_folder, diveNum, board_side=None, platform_detector=None, splash_detector=None, diver_detector=None, pose_model=None):
46 | handstand, expected_som, expected_twists, back_facing, expected_direction, position = getDiveInfo_from_diveNum(diveNum)
47 | dive_data = {}
48 | takeoff = []
49 | twist = []
50 | som = []
51 | entry = []
52 | distance_from_board = []
53 | position_tightness = []
54 | feet_apart = []
55 | over_under_rotation = []
56 | splash = []
57 | pose_preds = []
58 | diver_boxes = []
59 | above_boards = []
60 | on_boards = []
61 | som_counts = []
62 | twist_counts = []
63 | board_end_coords = []
64 | plat_outputs = []
65 | board_sides = []
66 | splash_pred_masks = []
67 | above_board = True
68 | on_board = True
69 | if platform_detector is None:
70 | platform_detector = get_platform_detector()
71 | if splash_detector is None:
72 | splash_detector = get_splash_detector()
73 | if diver_detector is None:
74 | diver_detector = get_diver_detector()
75 | if pose_model is None:
76 | pose_model = get_pose_model()
77 | key = (first_folder, int(second_folder))
78 | dive_folder_num = "{}_{}".format(first_folder, second_folder)
79 | directory = './FineDiving/datasets/FINADiving_MTL_256s/{}/{}/'.format(first_folder, second_folder)
80 | file_names = os.listdir(directory)
81 |
82 | ## find board_side
83 | if board_side is None:
84 | for i in range(len(file_names)):
85 | filepath = directory + file_names[i]
86 | if file_names[i][-4:] != ".jpg":
87 | continue
88 | im = cv2.imread(filepath)
89 | plat_output = platform_detector(im)
90 | board_side = find_which_side_board_on(plat_output)
91 | if board_side is not None:
92 | board_sides.append(board_side)
93 | dive_data['board_sides'] = board_sides
94 | board_sides.sort()
95 | board_side = board_sides[len(board_sides)//2]  # middle of the sorted left/right votes = majority side
96 | dive_data['board_side'] = board_side
97 |
98 | prev_pred = None
99 | som_prev_pred = None
100 | half_som_count=0
101 | petal_count = 0
102 | in_petal = False
103 | for i in range(len(file_names)):
104 | filepath = directory + file_names[i]
105 | if file_names[i][-4:] != ".jpg":
106 | continue
107 | diver_box, pose_pred = get_pose_estimation(filepath, diver_detector=diver_detector, pose_model=pose_model)
108 | diver_boxes.append(diver_box)
109 | pose_preds.append(pose_pred)
110 |
111 | calculated_half_som_count, skip = som_counter(pose_pred, prev_pose_pred=som_prev_pred, half_som_count=half_som_count, handstand=handstand)
112 | if not skip:
113 | som_prev_pred = pose_pred
114 | calculated_petal_count, calculated_in_petal = twist_counter(pose_pred, prev_pose_pred=prev_pred, in_petal=in_petal, petal_count=petal_count)
115 | im = cv2.imread(filepath)
116 | plat_output = platform_detector(im)
117 | plat_outputs.append(plat_output)
118 | board_end_coord = board_end(plat_output, board_side=board_side)
119 | board_end_coords.append(board_end_coord)
120 | if above_board and not on_board and board_end_coord is not None and pose_pred is not None and np.array(pose_pred)[0][2][1] > int(board_end_coord[1]):  # right hip (MPII joint 2) has dropped below the platform edge
121 | above_board=False
122 | if on_board and detect_on_board(board_end_coord, board_side, pose_pred, handstand) is not None and not detect_on_board(board_end_coord, board_side, pose_pred, handstand):
123 | on_board = False
124 | if above_board:
125 | above_boards.append(1)
126 | else:
127 | above_boards.append(0)
128 | if on_board:
129 | on_boards.append(1)
130 | else:
131 | on_boards.append(0)
132 | calculated_takeoff = takeoff_microprogram_one_frame(filepath, above_board=above_board, on_board=on_board, pose_pred=pose_pred)
133 | calculated_twist = twist_microprogram_one_frame(filepath, on_board=on_board, pose_pred=pose_pred, expected_twists=expected_twists, petal_count=petal_count, expected_som=expected_som, half_som_count=half_som_count, diver_detector=diver_detector, pose_model=pose_model)
134 | calculated_som = somersault_microprogram_one_frame(filepath, pose_pred=pose_pred, on_board=on_board, expected_som=expected_som, half_som_count=half_som_count, expected_twists=expected_twists, petal_count=petal_count, diver_detector=diver_detector, pose_model=pose_model)
135 | calculated_entry = entry_microprogram_one_frame(filepath, above_board=above_board, on_board=on_board, pose_pred=pose_pred, expected_twists=expected_twists, petal_count=petal_count, expected_som=expected_som, half_som_count=half_som_count, splash_detector=splash_detector, visualize=False, dive_folder_num=dive_folder_num)
136 | if calculated_som == 1:
137 | half_som_count = calculated_half_som_count
138 | elif calculated_twist == 1:
139 | half_som_count = calculated_half_som_count
140 | petal_count = calculated_petal_count
141 | in_petal = calculated_in_petal
142 | dist = calculate_distance_from_platform_for_one_frame(filepath, visualize=False, pose_pred=pose_pred, diver_detector=diver_detector, pose_model=pose_model, board_end_coord=board_end_coord, platform_detector=platform_detector) # saves photo to ./output/data/distance_from_board/
143 | distance_from_board.append(dist)
144 | position_tightness.append(applyPositionTightnessError(filepath, pose_pred=pose_pred, diver_detector=diver_detector, pose_model=pose_model))
145 | feet_apart.append(applyFeetApartError(filepath, pose_pred=pose_pred, diver_detector=diver_detector, pose_model=pose_model))
146 | over_under_rotation.append(over_rotation(filepath, pose_pred=pose_pred, diver_detector=diver_detector, pose_model=pose_model))
147 | splash_area, splash_pred_mask = get_splash_from_one_frame(filepath, predictor=splash_detector, visualize=False)
148 | splash.append(splash_area)
149 | splash_pred_masks.append(splash_pred_mask)
150 | takeoff.append(calculated_takeoff)
151 | twist.append(calculated_twist)
152 | som.append(calculated_som)
153 | entry.append(calculated_entry)
154 | som_counts.append(half_som_count)
155 | twist_counts.append(petal_count)
156 | prev_pred = pose_pred
157 |
158 | dive_data['pose_pred'] = pose_preds
159 | dive_data['takeoff'] = takeoff
160 | dive_data['twist'] = twist
161 | dive_data['som'] = som
162 | dive_data['entry'] = entry
163 | dive_data['distance_from_board'] = distance_from_board
164 | dive_data['position_tightness'] = position_tightness
165 | dive_data['feet_apart'] = feet_apart
166 | dive_data['over_under_rotation'] = over_under_rotation
167 | dive_data['splash'] = splash
168 | dive_data['above_boards'] = above_boards
169 | dive_data['on_boards'] = on_boards
170 | dive_data['som_counts'] = som_counts
171 | dive_data['twist_counts'] = twist_counts
172 | dive_data['board_end_coords'] = board_end_coords
173 | dive_data['diver_boxes'] = diver_boxes
174 | dive_data['splash_pred_masks'] = splash_pred_masks
175 | dive_data['board_side'] = board_side
176 | dive_data['is_handstand'] = handstand
177 | dive_data['direction'] = expected_direction
178 | dive_data['plat_outputs'] = plat_outputs
179 | return dive_data
180 |
181 | if __name__ == '__main__':
182 | # Set up command-line arguments
183 | new_parser = argparse.ArgumentParser(description="Extract dive data to be used for scoring.")
184 | new_parser.add_argument("FineDiving_key", type=str, nargs=2, help="key from FineDiving Dataset (e.g. 01 1)")
185 | meta_program_args = new_parser.parse_args()
186 |
187 | # Fine-grained annotations from FineDiving Dataset
188 | with open('FineDiving/Annotations/fine-grained_annotation_aqa.pkl', 'rb') as f:
189 | dive_annotation_data = pickle.load(f)
190 |
191 | key = tuple(meta_program_args.FineDiving_key)
192 | key = (key[0], int(key[1]))
193 | print(key)
194 | platform_detector = get_platform_detector()
195 | splash_detector = get_splash_detector()
196 | diver_detector = get_diver_detector()
197 | pose_model = get_pose_model()
198 | diveNum = dive_annotation_data[key][0]
199 | print(diveNum)
200 | dive_data = aqa_metaprogram_finediving(key[0], key[1], diveNum, platform_detector=platform_detector, splash_detector=splash_detector, diver_detector=diver_detector, pose_model=pose_model)
201 |
202 | save_path = "./output/{}_{}.pkl".format(key[0], key[1])
203 | with open(save_path, 'wb') as f:
204 | print("saving data into " + save_path)
205 | pickle.dump(dive_data, f)
206 |
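207 | # Note: getDiveInfo_from_diveNum decodes FINA-style dive numbers, e.g. '5253b' ->
208 | # twisting group, back takeoff, 5 half-somersaults, 3 half-twists, pike position.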
--------------------------------------------------------------------------------
/rule_based_programs/aqa_metaProgram.py:
--------------------------------------------------------------------------------
1 | """
2 | aqa_metaProgram.py
3 | Author: Lauren Okamoto
4 | """
5 |
6 | from rule_based_programs.microprograms.dive_error_functions import *
7 | from rule_based_programs.microprograms.temporal_segmentation_functions import *
8 | from rule_based_programs.microprograms.dive_recognition_functions import *
9 | from rule_based_programs.scoring_functions import get_scale_factor
10 | from models.detectron2.detectors import get_platform_detector, get_diver_detector, get_splash_detector
11 | from models.pose_estimator.pose_estimator_model_setup import get_pose_estimation, get_pose_model
12 | import gradio as gr
13 | import pickle
14 | import os, sys, math
15 | import numpy as np
16 | import cv2
17 | import argparse
18 |
19 | def extract_frames(video_path):
20 | cap = cv2.VideoCapture(video_path)
21 | # Check if the video file is opened successfully
22 | if not cap.isOpened():
23 | print("Error: Couldn't open video file.")
24 | exit()
25 | frame_skip = 1  # with this loop, a value of 1 keeps every other frame
26 | # a variable to keep track of the frame to be saved
27 | frame_count = 0
28 | frames = []
29 | i = 0
30 | while True:
31 | ret, frame = cap.read()
32 | if not ret:
33 | break
34 | if i > frame_skip - 1:
35 | frame_count += 1
36 | frame = cv2.resize(frame, (455, 256)) # resize takes argument (width, height)
37 | frames.append(frame)
38 | i = 0
39 | continue
40 | i += 1
41 | cap.release()
42 | return frames
43 |
44 | def getDiveInfo_from_diveNum(diveNum):
45 | handstand = (diveNum[0] == '6')
46 | expected_som = int(diveNum[2])
47 | if len(diveNum) == 5:
48 | expected_twists = int(diveNum[3])
49 | else:
50 | expected_twists = 0
51 | if diveNum[0] == '1' or diveNum[0] == '3' or diveNum[:2] == '51' or diveNum[:2] == '53' or diveNum[:2] == '61' or diveNum[:2] == '63':
52 | back_facing = False
53 | else:
54 | back_facing = True
55 | if diveNum[0] == '1' or diveNum[:2] == '51' or diveNum[:2] == '61':
56 | expected_direction = 'front'
57 | elif diveNum[0] == '2' or diveNum[:2] == '52' or diveNum[:2] == '62':
58 | expected_direction = 'back'
59 | elif diveNum[0] == '3' or diveNum[:2] == '53' or diveNum[:2] == '63':
60 | expected_direction = 'reverse'
61 | elif diveNum[0] == '4':
62 | expected_direction = 'inward'
63 | if diveNum[-1] == 'b':
64 | position = 'pike'
65 | elif diveNum[-1] == 'c':
66 | position = 'tuck'
67 | else:
68 | position = 'free'
69 | return handstand, expected_som, expected_twists, back_facing, expected_direction, position
70 |
71 | def getDiveInfo_from_symbols(frames, dive_data=None, platform_detector=None, splash_detector=None, diver_detector=None, pose_model=None):
72 | print("Getting dive info from symbols...")
73 | if dive_data is None:
74 | print("somethings not getting passed in properly")
75 | dive_data = abstractSymbols(frames, platform_detector=platform_detector, splash_detector=splash_detector, diver_detector=diver_detector, pose_model=pose_model)
76 |
77 | # get above_boards, on_boards, and position_tightness
78 | above_board = True
79 | on_board = True
80 | above_boards = []
81 | on_boards = []
82 | position_tightness = []
83 | distances = []
84 | prev_board_coord = None
85 | for i in range(len(dive_data['pose_pred'])):
86 | pose_pred = dive_data['pose_pred'][i]
87 | board_end_coord = dive_data['board_end_coords'][i]
88 | if board_end_coord is not None and prev_board_coord is not None:
89 | distances.append(math.dist(board_end_coord, prev_board_coord))
90 | if math.dist(board_end_coord, prev_board_coord) > 150:  # platform edge jumped too far between frames; treat this detection as unreliable
91 | position_tightness.append(applyPositionTightnessError(filepath="", pose_pred=pose_pred, diver_detector=diver_detector, pose_model=pose_model))
92 | if above_board:
93 | above_boards.append(1)
94 | else:
95 | above_boards.append(0)
96 | if on_board:
97 | on_boards.append(1)
98 | else:
99 | on_boards.append(0)
100 | continue
101 | if above_board and not on_board and board_end_coord is not None and pose_pred is not None and np.array(pose_pred)[0][2][1] > int(board_end_coord[1]):  # right hip (MPII joint 2) has dropped below the platform edge
102 | above_board=False
103 | if on_board:
104 | handstand = is_handstand(dive_data)
105 | calculate_on_board = detect_on_board(board_end_coord, dive_data['board_side'], pose_pred, handstand)
106 | if calculate_on_board is not None and not calculate_on_board:
107 | on_board = False
108 | if above_board:
109 | above_boards.append(1)
110 | else:
111 | above_boards.append(0)
112 | if on_board:
113 | on_boards.append(1)
114 | else:
115 | on_boards.append(0)
116 | prev_board_coord = board_end_coord
117 | position_tightness.append(applyPositionTightnessError(filepath="", pose_pred=pose_pred, diver_detector=diver_detector, pose_model=pose_model))
118 | dive_data['on_boards'] = on_boards
119 | dive_data['above_boards'] = above_boards
120 | dive_data['position_tightness'] = position_tightness
121 |
122 | ## handstand and som_count##
123 | expected_som, handstand = som_counter_full_dive(dive_data)
124 |
125 | ## twist_count
126 | expected_twists = twist_counter_full_dive(dive_data)
127 |
128 | ## direction: front, back, reverse, inward
129 | expected_direction = get_direction(dive_data)
130 |
131 | return handstand, expected_som, expected_twists, expected_direction, dive_data
132 |
133 |
134 | def abstractSymbols(frames, progress=gr.Progress(), platform_detector=None, splash_detector=None, diver_detector=None, pose_model=None):
135 | print("Abstracting symbols...")
136 | splashes = []
137 | pose_preds = []
138 | board_sides = []
139 | plat_outputs = []
140 | diver_boxes = []
141 | splash_pred_masks = []
142 | if platform_detector is None:
143 | platform_detector = get_platform_detector()
144 | if splash_detector is None:
145 | splash_detector = get_splash_detector()
146 | if diver_detector is None:
147 | diver_detector = get_diver_detector()
148 | if pose_model is None:
149 | pose_model = get_pose_model()
150 | num_frames = len(frames)
151 | i = 0
152 | for frame in frames:
153 | progress(i/num_frames, desc="Abstracting Symbols")
154 | plat_output = platform_detector(frame)
155 | plat_outputs.append(plat_output)
156 | board_side = find_which_side_board_on(plat_output)
157 | if board_side is not None:
158 | board_sides.append(board_side)
159 | diver_box, pose_pred = get_pose_estimation(filepath="", image_bgr=frame, diver_detector=diver_detector, pose_model=pose_model)
160 | pose_preds.append(pose_pred)
161 | diver_boxes.append(diver_box)
162 | splash_area, splash_pred_mask = get_splash_from_one_frame(filepath="", im=frame, predictor=splash_detector, visualize=False)
163 | splash_pred_masks.append(splash_pred_mask)
164 | splashes.append(splash_area)
165 | i+=1
166 | dive_data = {}
167 | dive_data['plat_outputs'] = plat_outputs
168 | dive_data['pose_pred'] = pose_preds
169 | dive_data['splash'] = splashes
170 | dive_data['splash_pred_masks'] = splash_pred_masks
171 | dive_data['board_sides'] = board_sides
172 | board_sides.sort()
173 |     board_side = board_sides[len(board_sides)//2]  # median of the per-frame side detections
174 | dive_data['board_side'] = board_side
175 | dive_data['diver_boxes'] = diver_boxes
176 |
177 | # get board_end_coords
178 | board_end_coords = []
179 | for plat_output in dive_data['plat_outputs']:
180 | board_end_coord = board_end(plat_output, board_side=dive_data['board_side'])
181 | board_end_coords.append(board_end_coord)
182 | dive_data['board_end_coords'] = board_end_coords
183 |
184 | return dive_data
185 |
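# For reference, abstractSymbols() returns per-frame parallel lists keyed as:
#   'plat_outputs'      - raw detectron2 platform predictions
#   'pose_pred'         - HRNet pose estimates (or None when no diver is found)
#   'splash'            - splash areas (or None when no splash is detected)
#   'splash_pred_masks' - splash segmentation masks
#   'diver_boxes'       - diver bounding boxes
#   'board_end_coords'  - (x, y) of the platform tip per frame
# plus the scalar 'board_side' (median of the per-frame 'board_sides' votes).
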
186 | def aqa_metaprogram(frames, dive_data, progress=gr.Progress(), diveNum="", board_side=None, platform_detector=None, splash_detector=None, diver_detector=None, pose_model=None):
187 | print("AQA Metaprogram...")
188 | if len(frames) != len(dive_data['pose_pred']):
189 | raise gr.Error("Abstract Symbols first!")
190 | if diveNum != "":
191 | dive_num_given = True
192 | handstand, expected_som, expected_twists, back_facing, expected_direction, position = getDiveInfo_from_diveNum(diveNum)
193 | else:
194 | dive_num_given = False
195 | handstand, expected_som, expected_twists, expected_direction, dive_data = getDiveInfo_from_symbols(frames, dive_data=dive_data, platform_detector=platform_detector, splash_detector=splash_detector, diver_detector=diver_detector, pose_model=pose_model)
196 |
197 | if not dive_num_given:
198 | above_boards = dive_data['above_boards']
199 | on_boards = dive_data['on_boards']
200 | position_tightness = dive_data['position_tightness']
201 | board_end_coords = dive_data['board_end_coords']
202 | else:
203 | above_board = True
204 | on_board = True
205 | above_boards = []
206 | on_boards = []
207 | board_end_coords = []
208 | position_tightness = []
209 | splash = dive_data['splash']
210 | diver_boxes = dive_data['diver_boxes']
211 | board_side = dive_data['board_side']
212 | pose_preds = dive_data['pose_pred']
213 | takeoff = []
214 | twist = []
215 | som = []
216 | entry = []
217 | distance_from_board = []
218 | feet_apart = []
219 | over_under_rotation = []
220 | som_counts = []
221 | twist_counts = []
222 |
223 | if platform_detector is None:
224 | platform_detector = get_platform_detector()
225 | if splash_detector is None:
226 | splash_detector = get_splash_detector()
227 | if diver_detector is None:
228 | diver_detector = get_diver_detector()
229 | if pose_model is None:
230 | pose_model = get_pose_model()
231 |
232 |     # rolling state for the frame loop: previous poses for frame-to-frame
233 |     # comparisons, plus the half-somersault and twist ("petal") counters
234 |     prev_pred = None
235 |     som_prev_pred = None
236 |     half_som_count = 0
237 |     petal_count = 0
238 |     in_petal = False
237 | num_frames = len(frames)
238 | for i in range(num_frames):
239 | progress(i/num_frames, desc="Calculating Dive Errors")
240 | pose_pred = pose_preds[i]
241 | calculated_half_som_count, skip = som_counter(pose_pred, prev_pose_pred=som_prev_pred, half_som_count=half_som_count, handstand=handstand)
242 | if not skip:
243 | som_prev_pred = pose_pred
244 | calculated_petal_count, calculated_in_petal = twist_counter(pose_pred, prev_pose_pred=prev_pred, in_petal=in_petal, petal_count=petal_count)
245 | if dive_num_given:
246 | outputs = platform_detector(frames[i])
247 | board_end_coord = board_end(outputs, board_side=board_side)
248 | board_end_coords.append(board_end_coord)
249 | if above_board and not on_board and board_end_coord is not None and pose_pred is not None and np.array(pose_pred)[0][2][1] > int(board_end_coord[1]):
250 | above_board=False
251 |             if on_board:
252 |                 on_board_now = detect_on_board(board_end_coord, board_side, pose_pred, handstand)
253 |                 if on_board_now is not None and not on_board_now:
254 |                     on_board = False
253 | if above_board:
254 | above_boards.append(1)
255 | else:
256 | above_boards.append(0)
257 | if on_board:
258 | on_boards.append(1)
259 | else:
260 | on_boards.append(0)
261 | else:
262 | board_end_coord = board_end_coords[i]
263 | above_board = (above_boards[i] == 1)
264 | on_board = (on_boards[i] == 1)
265 | calculated_takeoff = takeoff_microprogram_one_frame(filepath="", above_board=above_board, on_board=on_board, pose_pred=pose_pred)
266 | calculated_twist = twist_microprogram_one_frame(filepath="", on_board=on_board, pose_pred=pose_pred, expected_twists=expected_twists, petal_count=petal_count, expected_som=expected_som, half_som_count=half_som_count, diver_detector=diver_detector, pose_model=pose_model)
267 | calculated_som = somersault_microprogram_one_frame(filepath="", pose_pred=pose_pred, on_board=on_board, expected_som=expected_som, half_som_count=half_som_count, expected_twists=expected_twists, petal_count=petal_count, diver_detector=diver_detector, pose_model=pose_model)
268 | calculated_entry = entry_microprogram_one_frame(filepath="", frame=frames[i], above_board=above_board, on_board=on_board, pose_pred=pose_pred, expected_twists=expected_twists, petal_count=petal_count, expected_som=expected_som, half_som_count=half_som_count, splash_detector=splash_detector, visualize=False)
269 | if calculated_som == 1:
270 | half_som_count = calculated_half_som_count
271 | elif calculated_twist == 1:
272 | half_som_count = calculated_half_som_count
273 | petal_count = calculated_petal_count
274 | in_petal = calculated_in_petal
275 | # distance from board
276 | dist = calculate_distance_from_platform_for_one_frame(filepath="", im=frames[i], visualize=False, pose_pred=pose_pred, diver_detector=diver_detector, pose_model=pose_model, board_end_coord=board_end_coord, platform_detector=platform_detector) # saves photo to ./output/data/distance_from_board/
277 | distance_from_board.append(dist)
278 | if dive_num_given:
279 | position_tightness.append(applyPositionTightnessError(filepath="", pose_pred=pose_pred, diver_detector=diver_detector, pose_model=pose_model))
280 | feet_apart.append(applyFeetApartError(filepath="", pose_pred=pose_pred, diver_detector=diver_detector, pose_model=pose_model))
281 | over_under_rotation.append(over_rotation(filepath="", pose_pred=pose_pred, diver_detector=diver_detector, pose_model=pose_model))
282 | takeoff.append(calculated_takeoff)
283 | twist.append(calculated_twist)
284 | som.append(calculated_som)
285 | entry.append(calculated_entry)
286 | som_counts.append(half_som_count)
287 | twist_counts.append(petal_count)
288 | prev_pred = pose_pred
289 |
290 | dive_data['takeoff'] = takeoff
291 | dive_data['twist'] = twist
292 | dive_data['som'] = som
293 | dive_data['entry'] = entry
294 | dive_data['distance_from_board'] = distance_from_board
295 | dive_data['position_tightness'] = position_tightness
296 | dive_data['feet_apart'] = feet_apart
297 | dive_data['over_under_rotation'] = over_under_rotation
298 | dive_data['above_boards'] = above_boards
299 | dive_data['on_boards'] = on_boards
300 | dive_data['som_counts'] = som_counts
301 | dive_data['twist_counts'] = twist_counts
302 | dive_data['board_end_coords'] = board_end_coords
303 | dive_data['is_handstand'] = handstand
304 | dive_data['direction'] = expected_direction
305 | return dive_data
306 |
307 | if __name__ == '__main__':
308 | # Set up command-line arguments
309 | new_parser = argparse.ArgumentParser(description="Extract dive data to be used for scoring.")
310 | new_parser.add_argument("video_path", type=str, help="Path to dive video (mp4 format).")
311 | meta_program_args = new_parser.parse_args()
312 |
313 | video_path = meta_program_args.video_path
314 | frames = extract_frames(video_path)
315 | platform_detector = get_platform_detector()
316 | splash_detector = get_splash_detector()
317 | diver_detector = get_diver_detector()
318 | pose_model = get_pose_model()
319 |
320 | dive_data = abstractSymbols(frames, platform_detector=platform_detector, splash_detector=splash_detector, diver_detector=diver_detector, pose_model=pose_model)
321 | dive_data = aqa_metaprogram(frames, dive_data, platform_detector=platform_detector, splash_detector=splash_detector, diver_detector=diver_detector, pose_model=pose_model)
322 |
323 | save_path = "./output/{}.pkl".format("".join(video_path.split('.')[:-1]))
324 | with open(save_path, 'wb') as f:
325 | print("saving data into " + save_path)
326 | pickle.dump(dive_data, f)
327 |
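# Example invocation (hypothetical file name):
#   python aqa_metaProgram.py my_dive.mp4
# writes the extracted dive_data dictionary to ./output/my_dive.pkl, ready to be
# scored with rule_based_programs/scoring_functions.get_all_report_scores.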
--------------------------------------------------------------------------------
/rule_based_programs/microprograms/dive_recognition_functions.py:
--------------------------------------------------------------------------------
1 | """
2 | dive_recognition_functions.py
3 | Author: Lauren Okamoto
4 | """
5 |
6 | import pickle
7 | import numpy as np
8 | import os
9 | import math
10 | from scipy.signal import find_peaks
11 | from matplotlib import pyplot as plt
12 |
13 | def get_scale_factor(dive_data):
14 |     # median thorax-to-pelvis distance, used as a pixel-to-body-length scale
15 |     distances = []
16 |     for pose_pred in dive_data['pose_pred']:
17 |         if pose_pred is not None:
18 |             distances.append(math.dist(pose_pred[0][6], pose_pred[0][7]))
19 |     return np.median(distances)
21 |
22 | def find_angle(vector1, vector2):
23 | unit_vector_1 = vector1 / np.linalg.norm(vector1)
24 | unit_vector_2 = vector2 / np.linalg.norm(vector2)
25 | dot_product = np.dot(unit_vector_1, unit_vector_2)
26 | angle = math.degrees(np.arccos(dot_product))
27 | return angle
28 |
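# Quick sanity sketch (values chosen for illustration): find_angle returns the
# unsigned angle between two 2-D vectors, in degrees.
def _demo_find_angle():
    assert round(find_angle([1, 0], [0, 1])) == 90    # perpendicular
    assert round(find_angle([1, 0], [-1, 0])) == 180  # anti-parallel
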
29 | def is_back_facing(dive_data, board_side):
30 | directions = []
31 | for i in range(len(dive_data['pose_pred'])):
32 | pose_pred = dive_data['pose_pred'][i]
33 | if pose_pred is None or dive_data['above_boards'][i] == 0:
34 | continue
35 | pose_pred = pose_pred[0]
36 |
37 | ## left knee bend ###
38 | l_knee = pose_pred[4]
39 | l_ankle = pose_pred[5]
40 | l_hip = pose_pred[3]
41 | l_knee_ankle = [l_ankle[0] - l_knee[0], 0-(l_ankle[1] - l_knee[1])]
42 | l_knee_hip = [l_hip[0] - l_knee[0], 0-(l_hip[1] - l_knee[1])]
43 | l_direction = rotation_direction(l_knee_hip, l_knee_ankle)
44 |
45 | ## right knee bend ###
46 | r_knee = pose_pred[1]
47 | r_ankle = pose_pred[0]
48 | r_hip = pose_pred[2]
49 | r_knee_ankle = [r_ankle[0] - r_knee[0], 0-(r_ankle[1] - r_knee[1])]
50 | r_knee_hip = [r_hip[0] - r_knee[0], 0-(r_hip[1] - r_knee[1])]
51 | r_direction = rotation_direction(r_knee_hip, r_knee_ankle)
52 | if l_direction == r_direction and l_direction != 0 and board_side == 'left':
53 | # we're looking for more clockwise
54 | return l_direction < 0
55 | elif l_direction == r_direction and l_direction != 0:
56 | # we're looking for more counterclockwise
57 | return l_direction > 0
58 | return False
59 |
60 | def rotation_direction(vector1, vector2, threshold=0.4):
61 | # Calculate the determinant to determine rotation direction
62 | determinant = vector1[0] * vector2[1] - vector1[1] * vector2[0]
63 | mag1= np.linalg.norm(vector1)
64 | mag2= np.linalg.norm(vector2)
65 | norm_det = determinant/(mag1*mag2)
66 | if norm_det > threshold:
67 | # "counterclockwise"
68 | return 1
69 | elif norm_det < 0-threshold:
70 | # "clockwise"
71 | return -1
72 | else:
73 |         # "indeterminate"
74 | return 0
75 |
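# Sign convention sketch (illustrative vectors): +1 is counterclockwise, -1 is
# clockwise, and 0 means the normalized determinant falls inside the dead band.
def _demo_rotation_direction():
    assert rotation_direction([1, 0], [0, 1]) == 1    # quarter turn CCW
    assert rotation_direction([1, 0], [0, -1]) == -1  # quarter turn CW
    assert rotation_direction([1, 0], [1, 0.1]) == 0  # ~5.7 degrees: below threshold
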
76 | # returns None if either pose_pred or board_end_coord is None
77 | # returns True if diver is on board, returns False if diver is off board
78 | def detect_on_board(board_end_coord, board_side, pose_pred, handstand):
79 | if pose_pred is None:
80 | return
81 | if board_end_coord is None:
82 | return
83 | if board_side == 'left':
84 | # if right of board end
85 | if np.array(pose_pred)[0][2][0] > int(board_end_coord[0]):
86 | if handstand:
87 | distance = math.dist(np.array(pose_pred)[0][15], board_end_coord) < math.dist(np.array(pose_pred)[0][14], np.array(pose_pred)[0][15]) * 1.5
88 | else:
89 | distance = math.dist(np.array(pose_pred)[0][5], board_end_coord) < math.dist(np.array(pose_pred)[0][5], np.array(pose_pred)[0][4]) * 1.5
90 |
91 | return distance
92 | # if left of board end
93 | else:
94 | return True
95 | else:
96 | # if right of board end
97 | if np.array(pose_pred)[0][2][0] > int(board_end_coord[0]):
98 | return True
99 | # if left of board end
100 | else:
101 | if handstand:
102 | distance = math.dist(np.array(pose_pred)[0][10], board_end_coord) < math.dist(np.array(pose_pred)[0][11], np.array(pose_pred)[0][10]) * 1.5
103 | else:
104 | distance = math.dist(np.array(pose_pred)[0][0], board_end_coord) < math.dist(np.array(pose_pred)[0][1], np.array(pose_pred)[0][0]) * 1.5
105 | return distance
106 |
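# Note on the 1.5x factor above (inferred from the MPII joint order used throughout
# this repo: 0/1 right ankle/knee, 5/4 left ankle/knee, 15/14 left wrist/elbow,
# 10/11 right wrist/elbow): once the diver's right hip has crossed the board tip,
# they still count as "on board" while the trailing ankle (or wrist, for armstand
# takeoffs) stays within 1.5 shin-lengths (forearm-lengths) of the tip.
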
107 | def find_position(dive_data):
108 | angles = []
109 | three_in_a_row = 0
110 | for i in range(1, len(dive_data['pose_pred'])):
111 | pose_pred = dive_data['pose_pred'][i]
112 | if pose_pred is None or dive_data['som'][i]==0:
113 | continue
114 | pose_pred = pose_pred[0]
115 | l_knee = pose_pred[4]
116 | l_ankle = pose_pred[5]
117 | l_hip = pose_pred[3]
118 | l_knee_ankle = [l_ankle[0] - l_knee[0], 0-(l_ankle[1] - l_knee[1])]
119 | l_knee_hip = [l_hip[0] - l_knee[0], 0-(l_hip[1] - l_knee[1])]
120 | angle = find_angle(l_knee_ankle, l_knee_hip)
121 | angles.append(angle)
122 | if angle < 70:
123 | three_in_a_row += 1
124 | if three_in_a_row >=3:
125 | return 'tuck'
126 | else:
127 | three_in_a_row =0
128 | if twist_counter_full_dive(dive_data) > 0 and som_counter_full_dive(dive_data)[0] < 5:
129 | return 'free'
130 | return 'pike'
131 |
132 |
133 | def distance_point_to_line_segment(px, py, x1, y1, x2, y2):
134 | # Calculate the squared distance from point (px, py) to the line segment [(x1, y1), (x2, y2)]
135 | def sqr_distance_point_to_segment():
136 | line_length_sq = (x2 - x1)**2 + (y2 - y1)**2
137 | if line_length_sq == 0:
138 | return (px - x1)**2 + (py - y1)**2
139 | t = max(0, min(1, ((px - x1) * (x2 - x1) + (py - y1) * (y2 - y1)) / line_length_sq))
140 | return ((px - (x1 + t * (x2 - x1)))**2 + (py - (y1 + t * (y2 - y1)))**2)
141 |
142 | # Calculate the closest point on the line segment to the given point (px, py)
143 | def closest_point_on_line_segment():
144 | line_length_sq = (x2 - x1)**2 + (y2 - y1)**2
145 | if line_length_sq == 0:
146 | return x1, y1
147 | t = max(0, min(1, ((px - x1) * (x2 - x1) + (py - y1) * (y2 - y1)) / line_length_sq))
148 | closest_x = x1 + t * (x2 - x1)
149 | closest_y = y1 + t * (y2 - y1)
150 | return closest_x, closest_y
151 |
152 | closest_point = closest_point_on_line_segment()
153 | distance = math.sqrt(sqr_distance_point_to_segment())
154 |
155 | return closest_point, distance
156 |
157 | def min_distance_from_line_to_circle(line_start, line_end, circle_center, circle_radius):
158 | closest_point, distance = distance_point_to_line_segment(circle_center[0], circle_center[1],
159 | line_start[0], line_start[1],
160 | line_end[0], line_end[1])
161 |
162 | min_distance = max(0, distance - circle_radius)
163 | return min_distance
164 |
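# Worked example (hypothetical numbers): the segment from (0, 0) to (10, 0) passes
# closest to a circle centred at (5, 3) at the point (5, 0); the centre is 3 away,
# so after subtracting a radius of 1 the minimum clearance is 2.
def _demo_min_distance_from_line_to_circle():
    assert min_distance_from_line_to_circle((0, 0), (10, 0), (5, 3), 1) == 2.0
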
165 | def twister(pose_pred, prev_pose_pred=None, in_petal=False, petal_count=0, outer=10, inner=9, valid=17, middle=0.5):
166 | if pose_pred is None:
167 | return petal_count, in_petal
168 | min_dist = 0
169 | pose_pred = pose_pred[0]
170 | vector1 = [pose_pred[2][0] - pose_pred[3][0], 0-(pose_pred[2][1] - pose_pred[3][1])]
171 | if prev_pose_pred is not None:
172 | prev_pose_pred = prev_pose_pred[0]
173 | prev_pose_pred = [prev_pose_pred[2][0] - prev_pose_pred[3][0], 0-(prev_pose_pred[2][1] - prev_pose_pred[3][1])]
174 | min_dist = min_distance_from_line_to_circle(prev_pose_pred, vector1, (0, 0), middle)
175 | if np.linalg.norm(vector1) > valid:
176 | return petal_count, in_petal
177 | if min_dist is not None and in_petal and np.linalg.norm(vector1) > outer and min_dist == 0:
178 | petal_count += 1
179 | elif not in_petal and np.linalg.norm(vector1) > outer:
180 | in_petal = True
181 | petal_count += 1
182 | elif in_petal and np.linalg.norm(vector1) < inner:
183 | in_petal = False
184 | return petal_count, in_petal
185 |
186 |
187 | def twist_counter(pose_pred, prev_pose_pred=None, in_petal=False, petal_count=0):
188 | valid = 17
189 | outer = 10
190 | inner = 9
191 | if pose_pred is None:
192 | return petal_count, in_petal
193 | min_dist = 0
194 | pose_pred = pose_pred[0]
195 | vector1 = [pose_pred[2][0] - pose_pred[3][0], 0-(pose_pred[2][1] - pose_pred[3][1])]
196 | if prev_pose_pred is not None:
197 | prev_pose_pred = prev_pose_pred[0]
198 | prev_pose_pred = [prev_pose_pred[2][0] - prev_pose_pred[3][0], 0-(prev_pose_pred[2][1] - prev_pose_pred[3][1])]
199 | min_dist = min_distance_from_line_to_circle(prev_pose_pred, vector1, (0, 0), 0.5)
200 | if np.linalg.norm(vector1) > valid:
201 | return petal_count, in_petal
202 | if min_dist is not None and in_petal and np.linalg.norm(vector1) > outer and min_dist == 0:
203 | petal_count += 1
204 | elif not in_petal and np.linalg.norm(vector1) > outer:
205 | in_petal = True
206 | elif in_petal and np.linalg.norm(vector1) < inner:
207 | in_petal = False
208 | petal_count += 1
209 | return petal_count, in_petal
210 |
211 |
212 | def twist_counter_full_dive(dive_data, visualize=False):
213 | dist_hip = []
214 | prev_pose_pred = None
215 | in_petal=False
216 | petal_count=0
217 | scale = get_scale_factor(dive_data)
218 | valid = scale / 1.5
219 | outer = scale / 3.2
220 | inner = scale / 3.4
221 | middle = 0.5
222 | next_next_pose_pred = dive_data['pose_pred'][4]
223 | for i in range(len(dive_data['pose_pred'])):
224 | pose_pred = dive_data['pose_pred'][i]
227 | if i < len(dive_data['pose_pred']) - 4 and next_next_pose_pred is not None:
228 | next_next_pose_pred = dive_data['pose_pred'][i + 4]
229 | if pose_pred is None or dive_data['on_boards'][i] == 1 or dive_data['position_tightness'][i] <= 80 or next_next_pose_pred is None:
230 | continue
231 | petal_count, in_petal = twister(pose_pred, prev_pose_pred=prev_pose_pred, in_petal=in_petal, petal_count=petal_count, outer=outer, inner=inner, middle=middle, valid=valid)
232 | prev_pose_pred = pose_pred
233 | if visualize:
234 | pose_pred = pose_pred[0]
235 | dist_hip.append([pose_pred[2][0] - pose_pred[3][0], 0-(pose_pred[2][1] - pose_pred[3][1])])
236 | if visualize:
237 | dist_hip = np.array(dist_hip)
238 | plt.plot(dist_hip[:, 0], dist_hip[:, 1], label="right-to-left hip")
239 | circle1 = plt.Circle((0, 0), outer, fill=False)
240 | plt.gca().add_patch(circle1)
241 | circle2 = plt.Circle((0, 0), inner, fill=False)
242 | plt.gca().add_patch(circle2)
243 | circle3 = plt.Circle((0, 0), valid, fill=False)
244 | plt.gca().add_patch(circle3)
245 | plt.legend()
246 | plt.show()
247 | return petal_count
248 |
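# Interpretation note (inferred from getDiveInfo below, where the twist digit of a
# dive number counts half-twists): as the diver twists, the projected right-to-left
# hip vector sweeps out one "petal" per half twist, so the returned petal_count is
# directly comparable to that digit.
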
249 | def rotation_direction_som(vector1, vector2, threshold=0.4):
250 | # Calculate the determinant to determine rotation direction
251 | determinant = vector1[0] * vector2[1] - vector1[1] * vector2[0]
252 | mag1= np.linalg.norm(vector1)
253 | mag2= np.linalg.norm(vector2)
254 | norm_det = determinant/(mag1*mag2)
255 | theta = np.arcsin(norm_det)
256 | return math.degrees(theta)
257 | def is_handstand(dive_data):
258 |     first_frame_pose_pred = dive_data['pose_pred'][0]
259 |     handstand = False
260 |     # pelvis (joint 6) above thorax (joint 7) at the start implies an armstand takeoff
261 |     if first_frame_pose_pred is not None and first_frame_pose_pred[0][6][1] < first_frame_pose_pred[0][7][1]:
262 |         handstand = True
263 |     return handstand
263 |
264 |
265 | def som_counter(pose_pred=None, prev_pose_pred=None, half_som_count=0, handstand=False):
266 | if pose_pred is None:
267 | return half_som_count, True
268 | pose_pred = pose_pred[0]
269 | vector1 = [pose_pred[7][0] - pose_pred[6][0], 0-(pose_pred[7][1] - pose_pred[6][1])] # flip y axis
270 | if (not handstand and half_som_count % 2 == 0) or (handstand and half_som_count %2 == 1):
271 | vector2 = [0, -1]
272 | else:
273 | vector2 = [0, 1]
274 | unit_vector_1 = vector1 / np.linalg.norm(vector1)
275 | unit_vector_2 = vector2 / np.linalg.norm(vector2)
276 | dot_product = np.dot(unit_vector_1, unit_vector_2)
277 | current_angle = math.degrees(np.arccos(dot_product))
278 | if prev_pose_pred is not None:
279 | prev_pose_pred = prev_pose_pred[0]
280 | prev_vector = [prev_pose_pred[7][0] - prev_pose_pred[6][0], 0-(prev_pose_pred[7][1] - prev_pose_pred[6][1])] # flip y axis
281 | prev_unit_vector = prev_vector / np.linalg.norm(prev_vector)
282 | prev_angle_diff = math.degrees(np.arccos(np.dot(unit_vector_1, prev_unit_vector)))
283 | if prev_angle_diff > 115:
284 | return half_som_count, True
285 | if current_angle <= 80:
286 | half_som_count += 1
287 | return half_som_count, False
288 |
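# How som_counter tallies rotation: the pelvis-to-thorax vector is compared against
# an alternating reference that points straight down for even half-som counts and
# straight up for odd ones (flipped for armstand dives). Coming within 80 degrees of
# the current reference registers one more half somersault; a frame whose torso
# direction jumps more than 115 degrees from the last accepted frame is skipped as a
# likely pose-estimation glitch (signalled by the second return value).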
289 |
290 | def som_counter_full_dive(dive_data, visualize=False):
291 | half_som_count = 0
292 | dist_body = []
293 | handstand = is_handstand(dive_data)
294 | next_next_pose_pred = dive_data['pose_pred'][2]
295 |     prev = None
296 |     is_clockwise = is_rotating_clockwise(dive_data)  # constant over the dive; hoisted out of the loop
296 | for i in range(len(dive_data['pose_pred'])):
297 | pose_pred = dive_data['pose_pred'][i]
298 | if i < len(dive_data['pose_pred']) - 2 and next_next_pose_pred is not None:
299 | next_next_pose_pred = dive_data['pose_pred'][i + 2]
300 | if pose_pred is None or next_next_pose_pred is None or dive_data['on_boards'][i] == 1:
301 | continue
302 | pose_pred = pose_pred[0]
303 | vector1 = [pose_pred[7][0] - pose_pred[6][0], 0-(pose_pred[7][1] - pose_pred[6][1])]
304 | if (not handstand and half_som_count % 2 == 0) or (handstand and half_som_count % 2 == 1):
305 | vector2 = [0, -1]
306 | else:
307 | vector2 = [0, 1]
308 | sensitivity = 115
309 | if prev is not None and find_angle(vector1, prev) > sensitivity:
310 | continue
312 | if prev is not None and ((is_clockwise and rotation_direction_som(vector1, prev)<0) or (not is_clockwise and rotation_direction_som(vector1, prev)>0)):
313 | continue
314 | angle = find_angle(vector1, vector2)
315 | if angle <= 75:
316 | half_som_count += 1
317 | if visualize:
318 | dist_body.append([pose_pred[7][0] - pose_pred[6][0], 0-(pose_pred[7][1] - pose_pred[6][1])])
319 | prev = vector1
320 | if visualize:
321 | dist_body = np.array(dist_body)
322 | plt.plot(dist_body[:, 0], dist_body[:, 1], label="pelvis-to-thorax")
323 | plt.xlabel("x-coord")
324 | plt.ylabel("y-coord")
325 | plt.legend()
326 | plt.show()
327 | return half_som_count, handstand
328 |
329 | def getDiveInfo(diveNum):
330 | handstand = (diveNum[0] == '6')
331 | expected_som = int(diveNum[2])
332 | if len(diveNum) == 5:
333 | expected_twists = int(diveNum[3])
334 | else:
335 | expected_twists = 0
336 | if diveNum[0] == '1' or diveNum[0] == '3' or diveNum[:2] == '51' or diveNum[:2] == '53' or diveNum[:2] == '61' or diveNum[:2] == '63':
337 | back_facing = False
338 | else:
339 | back_facing = True
340 | if diveNum[0] == '1' or diveNum[:2] == '51' or diveNum[:2] == '61':
341 | expected_direction = 'front'
342 | elif diveNum[0] == '2' or diveNum[:2] == '52' or diveNum[:2] == '62':
343 | expected_direction = 'back'
344 | elif diveNum[0] == '3' or diveNum[:2] == '53' or diveNum[:2] == '63':
345 | expected_direction = 'reverse'
346 | elif diveNum[0] == '4':
347 | expected_direction = 'inward'
348 | if diveNum[-1] == 'b':
349 | position = 'pike'
350 | elif diveNum[-1] == 'c':
351 | position = 'tuck'
352 | else:
353 | position = 'free'
354 | return handstand, expected_som, expected_twists, back_facing, expected_direction, position
355 |
356 | def get_direction(dive_data):
357 | clockwise = is_rotating_clockwise(dive_data)
358 | board_side = dive_data['board_side']
359 | if board_side == "right":
360 | back_facing = is_back_facing(dive_data, 'right')
361 | if back_facing and clockwise:
362 | direction = 'inward'
363 | elif back_facing and not clockwise:
364 | direction = 'back'
365 | elif not back_facing and clockwise:
366 | direction = 'reverse'
367 | elif not back_facing and not clockwise:
368 | direction = 'front'
369 | else:
370 | back_facing = is_back_facing(dive_data, 'left')
371 | if back_facing and clockwise:
372 | direction = 'back'
373 | elif back_facing and not clockwise:
374 | direction = 'inward'
375 | elif not back_facing and clockwise:
376 | direction = 'front'
377 | elif not back_facing and not clockwise:
378 | direction = 'reverse'
379 | return direction
380 |
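# The eight cases above collapse to a simple table (screen-space rotation vs. which
# way the diver faces, with the board side mirroring the answer):
#   board right: back+CW=inward, back+CCW=back,   front+CW=reverse, front+CCW=front
#   board left:  back+CW=back,   back+CCW=inward, front+CW=front,   front+CCW=reverse
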
381 | def is_rotating_clockwise(dive_data):
382 | directions = []
383 | for i in range(1, len(dive_data['pose_pred'])):
384 | if dive_data['pose_pred'][i] is None or dive_data['pose_pred'][i-1] is None:
385 | continue
386 | if dive_data['on_boards'][i] == 0:
387 | prev_pose_pred_hip = dive_data['pose_pred'][i-1][0][3]
388 | curr_pose_pred_hip = dive_data['pose_pred'][i][0][3]
389 | prev_pose_pred_knee = dive_data['pose_pred'][i-1][0][4]
390 | curr_pose_pred_knee = dive_data['pose_pred'][i][0][4]
391 | prev_hip_knee = [prev_pose_pred_knee[0] - prev_pose_pred_hip[0], 0-(prev_pose_pred_knee[1] - prev_pose_pred_hip[1])]
392 | curr_hip_knee = [curr_pose_pred_knee[0] - curr_pose_pred_hip[0], 0-(curr_pose_pred_knee[1] - curr_pose_pred_hip[1])]
393 | direction = rotation_direction(prev_hip_knee, curr_hip_knee, threshold=0)
394 | directions.append(direction)
395 | return np.sum(directions) < 0
396 |
--------------------------------------------------------------------------------
/rule_based_programs/scoring_functions.py:
--------------------------------------------------------------------------------
1 | """
2 | scoring_functions.py
3 | Author: Lauren Okamoto
4 | """
5 |
6 | import pickle
7 | import numpy as np
8 | import os
9 | import math
10 | from scipy.signal import find_peaks
11 | from rule_based_programs.microprograms.dive_recognition_functions import *
12 |
13 | ### All functions (excluding helper functions) take "dive_data" as input,
14 | ### which is the dictionary with all the information outputted by the AQA metaprogram
15 |
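### A minimal sketch of that dictionary's shape (keys as assigned in aqa_metaprogram;
### the values shown are illustrative placeholders, one entry per video frame):
#
#   dive_data = {
#       'pose_pred': [...], 'splash': [...], 'board_end_coords': [...],
#       'takeoff': [...], 'twist': [...], 'som': [...], 'entry': [...],
#       'above_boards': [...], 'on_boards': [...], 'position_tightness': [...],
#       'distance_from_board': [...], 'feet_apart': [...], 'over_under_rotation': [...],
#       'som_counts': [...], 'twist_counts': [...],
#       'board_side': 'left' or 'right', 'is_handstand': bool, 'direction': str,
#   }
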
16 | ############## HELPER FUNCTIONS ######################################
17 | def rotation_direction(vector1, vector2, threshold=0.4):
18 | # Calculate the determinant to determine rotation direction
19 | determinant = vector1[0] * vector2[1] - vector1[1] * vector2[0]
20 | mag1= np.linalg.norm(vector1)
21 | mag2= np.linalg.norm(vector2)
22 | norm_det = determinant/(mag1*mag2)
23 | if norm_det > threshold:
24 | # "counterclockwise"
25 | return 1
26 | elif norm_det < 0-threshold:
27 | # "clockwise"
28 | return -1
29 | else:
30 |         # "indeterminate"
31 | return 0
32 |
33 | def find_angle(vector1, vector2):
34 | unit_vector_1 = vector1 / np.linalg.norm(vector1)
35 | unit_vector_2 = vector2 / np.linalg.norm(vector2)
36 | dot_product = np.dot(unit_vector_1, unit_vector_2)
37 | angle = math.degrees(np.arccos(dot_product))
38 | return angle
39 |
40 | #################################################################
41 |
42 | def height_off_board_score(dive_data):
43 | above_board_indices = [i for i in range(0, len(dive_data['distance_from_board'])) if dive_data['above_boards'][i]==1]
44 | takeoff_indices = [i for i in range(0, len(dive_data['takeoff'])) if dive_data['takeoff'][i]==1]
45 | final_indices = []
46 | prev_board_end_coord = None
47 | for i in range(len(above_board_indices)):
48 | board_end_coord = dive_data['board_end_coords'][i]
49 | if board_end_coord is not None and board_end_coord[1] < 30:
50 | continue
51 | if board_end_coord is not None and prev_board_end_coord is not None and math.dist(board_end_coord, prev_board_end_coord) > 150:
52 | continue
53 | if above_board_indices[i] not in takeoff_indices:
54 | final_indices.append(above_board_indices[i])
55 | prev_board_end_coord = board_end_coord
56 |
57 | heights = []
58 | for i in range(len(final_indices)):
59 | pose_pred = dive_data['pose_pred'][final_indices[i]]
60 | board_end_coord = dive_data['board_end_coords'][final_indices[i]]
61 | if pose_pred is None or board_end_coord is None:
62 | continue
63 | pose_pred = pose_pred[0]
64 | min_height = float('inf')
65 | for j in range(len(pose_pred)):
66 | if board_end_coord[1] - pose_pred[j][1]< min_height:
67 | min_height = board_end_coord[1] - pose_pred[j][1]
68 | if min_height < 0:
69 | min_height = 0
70 | heights.append(min_height)
71 | if len(heights) == 0:
72 | return None, None
73 | max_scaled_height = max(heights) / get_scale_factor(dive_data)
74 | return max_scaled_height, final_indices[np.argmax(heights)]
75 |
76 | def distance_from_board_score(dive_data):
77 | above_board_indices = [i for i in range(0, len(dive_data['distance_from_board'])) if dive_data['above_boards'][i]==1]
78 | takeoff_indices = [i for i in range(0, len(dive_data['takeoff'])) if dive_data['takeoff'][i]==1]
79 | final_indices = []
80 | for i in range(len(above_board_indices)):
81 | if above_board_indices[i] not in takeoff_indices:
82 | final_indices.append(above_board_indices[i])
83 | dists = np.array(dive_data['distance_from_board'])[final_indices]
84 | for i in range(len(dists)):
85 | if dists[i] is None:
86 | dists[i] = float('inf')
87 | min_scaled_dist = np.min(dists) / get_scale_factor(dive_data)
88 | too_close_threshold = 0.25
89 |     too_far_threshold = 1.8
90 |     if 'diveNum' in dive_data:
91 |         if dive_data['diveNum'][0] == '4':
92 |             too_far_threshold = 1.1
93 |         elif dive_data['diveNum'][0] == '1':
94 |             too_far_threshold = 1.6
95 |         elif dive_data['diveNum'][0] == '2':
96 |             too_far_threshold = 1.8
97 |         elif dive_data['diveNum'][0] == '3':
98 |             too_far_threshold = 1.6
99 |         elif dive_data['diveNum'][0] == '5':
100 |             too_far_threshold = 1.5
101 |         elif dive_data['diveNum'][0] == '6':
102 |             too_far_threshold = 1.1
104 | # good distance
105 | if min_scaled_dist < too_far_threshold and min_scaled_dist > too_close_threshold:
106 | return 0, min_scaled_dist, final_indices[np.argmin(dists)]
107 | # too far
108 | if min_scaled_dist >= too_far_threshold:
109 | return 1, min_scaled_dist, final_indices[np.argmin(dists)]
110 | # too close
111 | if min_scaled_dist <= too_close_threshold:
112 | return -1, min_scaled_dist, final_indices[np.argmin(dists)]
113 |     return min_scaled_dist  # reached only if min_scaled_dist is NaN; the branches above are otherwise exhaustive
114 |
115 | def knee_bend_score(dive_data):
116 | if find_position(dive_data) == 'tuck':
117 | return None, None
118 | knee_bends = []
119 | for i in range(len(dive_data['pose_pred'])):
120 | if dive_data['som'][i] == 0:
121 | continue
122 | pose_pred = dive_data['pose_pred'][i]
123 | if pose_pred is None:
124 | continue
125 | pose_pred = pose_pred[0]
126 | knee_to_ankle = [pose_pred[1][0] - pose_pred[0][0], 0-(pose_pred[1][1]-pose_pred[0][1])]
127 | knee_to_hip = [pose_pred[1][0] - pose_pred[2][0], 0-(pose_pred[1][1]-pose_pred[2][1])]
128 | knee_bend = find_angle(knee_to_ankle, knee_to_hip)
129 | knee_bends.append(knee_bend)
130 | if len(knee_bends) == 0:
131 | return None, None
132 | som_indices = [i for i in range(0, len(dive_data['som'])) if dive_data['som'][i]==1]
133 | som_avg_knee_bend = np.mean(knee_bends)
134 | return 180 - som_avg_knee_bend, som_indices
135 |
136 | def position_tightness_score(dive_data):
137 | som_indices = [i for i in range(0, len(dive_data['som'])) if dive_data['som'][i]==1]
138 | twist_indices = [i for i in range(0, len(dive_data['som'])) if dive_data['twist'][i]==1]
139 | som_tightness = np.array(dive_data['position_tightness'])[som_indices]
140 | twist_tightness = 180 - np.array(dive_data['position_tightness'])[twist_indices]
141 |
142 |     # keep only frames with a meaningful tightness angle (< 90 degrees)
143 | som_tightness = np.array(list(filter(lambda item: item is not None and item < 90, som_tightness)))
144 | twist_tightness = np.array(list(filter(lambda item: item is not None and item < 90, twist_tightness)))
145 | if len(som_indices) == 0:
146 | som_avg = None
147 | else:
148 | som_avg = np.mean(som_tightness)
149 | if len(twist_indices)==0:
150 | return som_avg, None, som_indices, twist_indices
151 | twist_avg = np.mean(twist_tightness)
152 | if som_avg is not None:
153 | som_avg -= 15
154 | return som_avg, twist_avg, som_indices, twist_indices
155 |
156 | def is_rotating_clockwise(dive_data):
157 | directions = []
158 | for i in range(1, len(dive_data['pose_pred'])):
159 | if dive_data['pose_pred'][i] is None or dive_data['pose_pred'][i-1] is None:
160 | continue
161 | if dive_data['on_boards'][i] == 0:
162 | prev_pose_pred_hip = dive_data['pose_pred'][i-1][0][3]
163 | curr_pose_pred_hip = dive_data['pose_pred'][i][0][3]
164 | prev_pose_pred_knee = dive_data['pose_pred'][i-1][0][4]
165 | curr_pose_pred_knee = dive_data['pose_pred'][i][0][4]
166 | prev_hip_knee = [prev_pose_pred_knee[0] - prev_pose_pred_hip[0], 0-(prev_pose_pred_knee[1] - prev_pose_pred_hip[1])]
167 | curr_hip_knee = [curr_pose_pred_knee[0] - curr_pose_pred_hip[0], 0-(curr_pose_pred_knee[1] - curr_pose_pred_hip[1])]
168 | direction = rotation_direction(prev_hip_knee, curr_hip_knee, threshold=0)
169 | directions.append(direction)
170 | return np.sum(directions) < 0
171 |
172 | def over_under_rotation_score(dive_data):
173 | entry_indices = [i for i in range(0, len(dive_data['entry'])) if dive_data['entry'][i]==1]
174 | over_under_rotation_error = np.array(dive_data['over_under_rotation'])[entry_indices]
175 | splashes = np.array(dive_data['splash'])[entry_indices]
176 | for i in range(len(over_under_rotation_error) - 1, -1, -1):
177 | if over_under_rotation_error[i] is None:
178 | continue
179 | else:
180 | # gets the second to last pose (assuming the last pose has incorrect pose estimation)
181 | index = i-2
182 | if index < 0:
183 | index = 0
184 | total_index = entry_indices[index]
185 | if splashes[index] is None and over_under_rotation_error[index] is not None:
186 | pose_pred = dive_data['pose_pred'][total_index][0]
187 | thorax_pelvis_vector = [pose_pred[1][0] - pose_pred[7][0], 0-(pose_pred[1][1]-pose_pred[7][1])]
188 | prev_pose_pred = dive_data['pose_pred'][total_index - 1]
189 | if prev_pose_pred is not None:
190 | prev_pose_pred = prev_pose_pred[0]
191 | prev_thorax_pelvis_vector = [prev_pose_pred[1][0] - prev_pose_pred[7][0], 0-(prev_pose_pred[1][1]-prev_pose_pred[7][1])]
192 | rotation_speed = find_angle(thorax_pelvis_vector, prev_thorax_pelvis_vector)
193 | else:
194 | rotation_speed = 10
195 | vector2 = [0, 1]
196 | if is_rotating_clockwise(dive_data):
197 | # if under-rotated
198 | if thorax_pelvis_vector[0] < 0:
199 | avg_leg_torso = find_angle(thorax_pelvis_vector, vector2) - rotation_speed
200 |
201 | else:
202 | avg_leg_torso = find_angle(thorax_pelvis_vector, vector2) + rotation_speed
203 | else:
204 | # if over-rotated
205 | if thorax_pelvis_vector[0] < 0:
206 | avg_leg_torso = find_angle(thorax_pelvis_vector, vector2) + rotation_speed
207 | else:
208 | avg_leg_torso = find_angle(thorax_pelvis_vector, vector2) - rotation_speed
209 | return np.abs(avg_leg_torso), entry_indices[index]
210 | break
211 |
212 | def straightness_during_entry_score(dive_data):
213 | entry_indices = [i for i in range(0, len(dive_data['entry'])) if dive_data['entry'][i]==1]
214 | straightness_during_entry = np.array(dive_data['position_tightness'])[entry_indices]
215 | over_under_rotation = over_under_rotation_score(dive_data)
216 | if over_under_rotation is not None:
217 | frame = over_under_rotation[1]
218 | index = entry_indices.index(frame) - 1
219 | if index < 0:
220 | index = 0
221 | return 180-straightness_during_entry[index], [frame-1, frame, frame + 1]
222 | splashes = np.array(dive_data['splash'])[entry_indices]
223 | for i in range(len(straightness_during_entry) - 1, -1, -1):
224 | if i > 0 and (straightness_during_entry[i] is None or splashes[i] is not None):
225 | continue
226 | else:
227 | # gets the second to last pose (assuming the last pose has incorrect pose estimation)
228 | if straightness_during_entry[i] is not None:
229 | if 180-straightness_during_entry[i] > 130:
230 | continue
231 | return 180-straightness_during_entry[i], entry_indices[i-1:i+2]
232 | break
233 |
234 | def splash_score(dive_data):
235 | entry_indices = [i for i in range(0, len(dive_data['entry'])) if dive_data['entry'][i]==1]
236 | if len(entry_indices) == 0:
237 | return None
238 | splash_indices=[i for i in range(0, len(dive_data['splash'])) if dive_data['splash'][i] is not None]
239 | splashes = np.array(dive_data['splash'])[entry_indices]
240 | for i in range(len(splashes)):
241 | if splashes[i] is None:
242 | splashes[i] = 0
243 | splashes = splashes / get_scale_factor(dive_data)**2
244 | # area under curve
245 | area = np.trapz(splashes, dx=5)
246 | return area, entry_indices[np.argmax(splashes)], splash_indices
247 |
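# Note on the integration above: np.trapz sums trapezoids over the per-frame splash
# areas (already normalized by the squared body-length scale) with a fixed spacing
# of dx=5, e.g. np.trapz([0, 2, 4, 2, 0], dx=5) == 40.0, so both larger and
# longer-lived splashes raise the raw score.
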
248 | # feet apart
249 | def feet_apart_score(dive_data):
250 | takeoff_indices = [i for i in range(0, len(dive_data['takeoff'])) if dive_data['takeoff'][i]==1]
251 | non_takeoff_indices = [i for i in range(len(dive_data['takeoff'])) if (i not in takeoff_indices and dive_data['splash'][i] is None)]
252 | feet_apart_error = np.array(dive_data['feet_apart'])[non_takeoff_indices]
253 | for i in range(len(feet_apart_error)):
254 | if feet_apart_error[i] is None or math.isnan(feet_apart_error[i]):
255 | feet_apart_error[i] = 0
256 | peaks, _ = find_peaks(feet_apart_error, height=5)
257 | if len(peaks) >= 1:
258 | peak_indices = np.array(non_takeoff_indices)[peaks]
259 | else:
260 | peak_indices = []
261 |     area = np.mean(feet_apart_error)  # despite the name, this is the mean per-frame error
262 |
263 | return area, peak_indices
264 |
265 | def find_position(dive_data):
266 | angles = []
267 | three_in_a_row = 0
268 | for i in range(1, len(dive_data['pose_pred'])):
269 | pose_pred = dive_data['pose_pred'][i]
270 | if pose_pred is None or dive_data['som'][i]==0:
271 | continue
272 | pose_pred = pose_pred[0]
273 | l_knee = pose_pred[4]
274 | l_ankle = pose_pred[5]
275 | l_hip = pose_pred[3]
276 | l_knee_ankle = [l_ankle[0] - l_knee[0], 0-(l_ankle[1] - l_knee[1])]
277 | l_knee_hip = [l_hip[0] - l_knee[0], 0-(l_hip[1] - l_knee[1])]
278 | angle = find_angle(l_knee_ankle, l_knee_hip)
279 | angles.append(angle)
280 | # print(angle)
281 | if angle < 70:
282 | three_in_a_row += 1
283 | if three_in_a_row >=3:
284 | return 'tuck'
285 | else:
286 | three_in_a_row =0
287 | return 'pike'
288 |
289 | def get_position_from_diveNum(dive_data):
290 | diveNum = dive_data['diveNum']
291 | position_code = diveNum[-1]
292 | if position_code == 'a':
293 | return "straight"
294 | elif position_code == 'b':
295 | return "pike"
296 | elif position_code == 'c':
297 | return "tuck"
298 | elif position_code == 'd':
299 | return "free"
300 | else:
301 | return None
302 |
303 | def get_all_report_scores(dive_data):
304 | with open('distribution_data.pkl', 'rb') as f:
305 | distribution_data = pickle.load(f)
306 | ## handstand and som_count##
307 | expected_som, handstand = som_counter_full_dive(dive_data)
308 | ## twist_count
309 | expected_twists = twist_counter_full_dive(dive_data)
310 | ## direction: front, back, reverse, inward
311 | expected_direction = get_direction(dive_data)
312 | dive_data['is_handstand'] = handstand
313 | dive_data['direction'] = expected_direction
314 |
315 | intermediate_scores = {}
316 | all_percentiles = []
317 | entry_indices = [i for i in range(0, len(dive_data['entry'])) if dive_data['entry'][i]==1]
318 |
319 | ### height off board ###
320 | if dive_data['is_handstand']:
321 | error_scores = distribution_data['armstand_height_off_board_scores']
322 | elif expected_twists >0:
323 | error_scores = distribution_data['twist_height_off_board_scores']
324 | elif dive_data['direction']=='front':
325 | error_scores = distribution_data['front_height_off_board_scores']
326 | elif dive_data['direction']=='back':
327 | error_scores = distribution_data['back_height_off_board_scores']
328 | elif dive_data['direction']=='reverse':
329 | error_scores = distribution_data['reverse_height_off_board_scores']
330 | elif dive_data['direction']=='inward':
331 | error_scores = distribution_data['inward_height_off_board_scores']
332 | error_scores = list(filter(lambda item: item is not None, error_scores))
333 | intermediate_scores['height_off_board'] = {}
334 | if dive_data['is_handstand']:
335 | intermediate_scores['height_off_board']['raw_score'] = None
336 | intermediate_scores['height_off_board']['frame_index'] = None
337 |     else:
338 |         raw_height, height_frame = height_off_board_score(dive_data)
339 |         intermediate_scores['height_off_board']['raw_score'] = raw_height
340 |         intermediate_scores['height_off_board']['frame_index'] = height_frame
340 | err = intermediate_scores['height_off_board']['raw_score']
341 | if err is not None:
342 | temp = error_scores
343 | temp.append(err)
344 | temp.sort()
345 | intermediate_scores['height_off_board']['percentile'] = temp.index(err)/len(temp)
346 | all_percentiles.append(temp.index(err)/len(temp))
347 | else:
348 | intermediate_scores['height_off_board']['percentile'] = None
349 |
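    # The percentile pattern used for this and every metric below is an empirical
    # rank: the raw error is inserted into the sorted reference distribution and its
    # index divided by the new length, e.g. inserting 3 into [1, 2, 4, 5] lands at
    # index 2 of 5 values, i.e. the 0.4 quantile. Metrics where a larger error is
    # worse report 1 - rank/len instead, so higher percentiles always read as better.
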
350 | ## distance from board ####
351 | error_scores = distribution_data['distance_from_board_scores']
352 | error_scores = list(filter(lambda item: item is not None, error_scores))
353 | intermediate_scores['distance_from_board'] = {}
354 |     dist_err, raw_dist, dist_frame = distance_from_board_score(dive_data)
355 |     intermediate_scores['distance_from_board']['raw_score'] = raw_dist
356 |     intermediate_scores['distance_from_board']['frame_index'] = dist_frame
357 |     err = dist_err
357 | if err is not None:
358 | if err == 1:
359 | intermediate_scores['distance_from_board']['percentile'] = "safe, but too far from"
360 | intermediate_scores['distance_from_board']['score'] = 0.5
361 |
362 | elif err == 0:
363 | intermediate_scores['distance_from_board']['percentile'] = "a good distance from"
364 | intermediate_scores['distance_from_board']['score'] = 1
365 | else:
366 | intermediate_scores['distance_from_board']['percentile'] = "too close to"
367 | intermediate_scores['distance_from_board']['score'] = 0
368 | all_percentiles.append(intermediate_scores['distance_from_board']['score'])
369 | else:
370 | intermediate_scores['distance_from_board']['percentile'] = None
371 | intermediate_scores['distance_from_board']['score'] = None
372 |
373 | ### feet_apart_scores ###
374 | error_scores = distribution_data['feet_apart_scores']
375 | error_scores = list(filter(lambda item: item is not None, error_scores))
376 | intermediate_scores['feet_apart'] = {}
377 |     feet_raw, feet_peaks = feet_apart_score(dive_data)
378 |     intermediate_scores['feet_apart']['raw_score'] = feet_raw
379 |     intermediate_scores['feet_apart']['peaks'] = feet_peaks
379 | err = intermediate_scores['feet_apart']['raw_score']
380 | if err is not None:
381 | temp = error_scores
382 | temp.append(err)
383 | temp.sort()
384 | intermediate_scores['feet_apart']['percentile'] = 1-temp.index(err)/len(temp)
385 | all_percentiles.append(1-temp.index(err)/len(temp))
386 | else:
387 | intermediate_scores['feet_apart']['percentile'] = None
388 |
389 | ### knee_bend_scores ###
390 | error_scores = distribution_data['knee_bend_scores']
391 | error_scores = list(filter(lambda item: item is not None, error_scores))
392 | intermediate_scores['knee_bend'] = {}
393 |     knee_raw, knee_frames = knee_bend_score(dive_data)
394 |     intermediate_scores['knee_bend']['raw_score'] = knee_raw
395 |     intermediate_scores['knee_bend']['frame_indices'] = knee_frames
395 | err = intermediate_scores['knee_bend']['raw_score']
396 | if err is not None:
397 | temp = error_scores
398 | temp.append(err)
399 | temp.sort()
400 | intermediate_scores['knee_bend']['percentile'] = 1-temp.index(err)/len(temp)
401 | all_percentiles.append(1-temp.index(err)/len(temp))
402 | else:
403 | intermediate_scores['knee_bend']['percentile'] = None
404 |
405 |
406 | ### som_position_tightness_scores ###
407 | error_scores = distribution_data['som_position_tightness_scores']
408 | error_scores = list(filter(lambda item: item is not None, error_scores))
409 | intermediate_scores['som_position_tightness'] = {}
410 | position = find_position(dive_data)
411 | if position == 'tuck':
412 | intermediate_scores['som_position_tightness']['position'] = 'tuck'
413 | else:
414 | intermediate_scores['som_position_tightness']['position'] = 'pike'
415 |     tightness = position_tightness_score(dive_data)
416 |     intermediate_scores['som_position_tightness']['raw_score'] = tightness[0]
417 |     intermediate_scores['som_position_tightness']['frame_indices'] = tightness[2]
417 | err = intermediate_scores['som_position_tightness']['raw_score']
418 | if err is not None:
419 | temp = error_scores
420 | temp.append(err)
421 | temp.sort()
422 | intermediate_scores['som_position_tightness']['percentile'] = 1-temp.index(err)/len(temp)
423 | all_percentiles.append(1-temp.index(err)/len(temp))
424 | else:
425 | intermediate_scores['som_position_tightness']['percentile'] = None
426 |
427 | ### twist_position_tightness_scores ###
428 | error_scores = distribution_data['twist_position_tightness_scores']
429 | error_scores = list(filter(lambda item: item is not None, error_scores))
430 | intermediate_scores['twist_position_tightness'] = {}
431 |     intermediate_scores['twist_position_tightness']['raw_score'] = tightness[1]   # tuple computed in the somersault section above
432 |     intermediate_scores['twist_position_tightness']['frame_indices'] = tightness[3]
433 | err = intermediate_scores['twist_position_tightness']['raw_score']
434 | if err is not None:
435 | temp = error_scores
436 | temp.append(err)
437 | temp.sort()
438 | intermediate_scores['twist_position_tightness']['percentile'] = 1-temp.index(err)/len(temp)
439 | all_percentiles.append(1-temp.index(err)/len(temp))
440 | else:
441 | intermediate_scores['twist_position_tightness']['percentile'] = None
442 |
443 | ### over_under_rotation_scores ###
444 | error_scores = distribution_data['over_under_rotation_scores']
445 | error_scores = list(filter(lambda item: item is not None, error_scores))
446 | intermediate_scores['over_under_rotation'] = {}
447 |     over_under = over_under_rotation_score(dive_data)
448 |     if over_under is not None:
449 |         intermediate_scores['over_under_rotation']['raw_score'] = over_under[0]
450 |         intermediate_scores['over_under_rotation']['frame_index'] = over_under[1]
450 | else:
451 | intermediate_scores['over_under_rotation']['raw_score'] = None
452 | intermediate_scores['over_under_rotation']['frame_index'] = None
453 | err = intermediate_scores['over_under_rotation']['raw_score']
454 | if err is not None:
455 | temp = error_scores
456 | temp.append(err)
457 | temp.sort()
458 | intermediate_scores['over_under_rotation']['percentile'] = 1-temp.index(err)/len(temp)
459 | all_percentiles.append(1-temp.index(err)/len(temp))
460 |
461 | else:
462 | intermediate_scores['over_under_rotation']['percentile'] = None
463 |
464 | ### splash_scores ###
465 | error_scores = distribution_data['splash_scores']
466 | error_scores = list(filter(lambda item: item is not None, error_scores))
467 | intermediate_scores['splash'] = {}
468 |     splash_raw, splash_max_index, splash_frames = splash_score(dive_data)
469 |     intermediate_scores['splash']['raw_score'] = splash_raw
470 |     intermediate_scores['splash']['maximum_index'] = splash_max_index
471 |     intermediate_scores['splash']['frame_indices'] = splash_frames
471 |
472 | err = intermediate_scores['splash']['raw_score']
473 | if err is not None:
474 | temp = error_scores
475 | temp.append(err)
476 | temp.sort()
477 | intermediate_scores['splash']['percentile'] = 1-temp.index(err)/len(temp)
478 | all_percentiles.append(1-temp.index(err)/len(temp))
479 |
480 | else:
481 | intermediate_scores['splash']['percentile'] = None
482 |
483 | ### straightness_during_entry_scores ###
484 | error_scores = distribution_data['straightness_during_entry_scores']
485 | error_scores = list(filter(lambda item: item is not None, error_scores))
486 | intermediate_scores['straightness_during_entry'] = {}
487 |     straightness = straightness_during_entry_score(dive_data)
488 |     if straightness is not None:
489 |         intermediate_scores['straightness_during_entry']['raw_score'] = straightness[0]
490 |         intermediate_scores['straightness_during_entry']['frame_indices'] = straightness[1]
491 |     else:
492 |         intermediate_scores['straightness_during_entry']['raw_score'] = None
493 |         intermediate_scores['straightness_during_entry']['frame_indices'] = None
493 |
494 | err = intermediate_scores['straightness_during_entry']['raw_score']
495 | if err is not None:
496 | temp = error_scores
497 | temp.append(err)
498 | temp.sort()
499 | intermediate_scores['straightness_during_entry']['percentile'] = 1-temp.index(err)/len(temp)
500 | all_percentiles.append(1-temp.index(err)/len(temp))
501 |
502 | else:
503 | intermediate_scores['straightness_during_entry']['percentile'] = None
504 |
505 | ## overall score ###
506 | # Excellent: 10
507 | # Very Good: 8.5-9.5
508 | # Good: 7.0-8.0
509 | # Satisfactory: 5.0-6.5
510 | # Deficient: 2.5-4.5
511 | # Unsatisfactory: 0.5-2.0
512 | # Completely failed: 0
513 | overall_score = np.mean(all_percentiles) * 10
514 | intermediate_scores['overall_score'] = {}
515 | intermediate_scores['overall_score']['raw_score'] = overall_score
516 | if overall_score == 10:
517 | intermediate_scores['overall_score']['description'] = 'excellent'
518 | elif overall_score >=8.5 and overall_score <10:
519 | intermediate_scores['overall_score']['description'] = 'very good'
520 | elif overall_score >=7 and overall_score <8.5:
521 | intermediate_scores['overall_score']['description'] = 'good'
522 | elif overall_score >=5 and overall_score <7:
523 | intermediate_scores['overall_score']['description'] = 'satisfactory'
524 | elif overall_score >=2.5 and overall_score <5:
525 | intermediate_scores['overall_score']['description'] = 'deficient'
526 | elif overall_score >0 and overall_score <2.5:
527 | intermediate_scores['overall_score']['description'] = 'unsatisfactory'
528 | else:
529 | intermediate_scores['overall_score']['description'] = 'completely failed'
530 |
531 | return intermediate_scores
532 |
--------------------------------------------------------------------------------
/score_report_generation/generate_report_functions.py:
--------------------------------------------------------------------------------
1 | """
2 | generate_report_functions.py
3 | Author: Lauren Okamoto
4 | """
5 |
6 | from jinja2 import Environment, FileSystemLoader
7 | import pickle
8 | import os
9 | import numpy as np
10 | from PIL import Image, ImageDraw
11 | from io import BytesIO
12 | import cv2
13 | import base64
14 | from pathlib import Path
15 | import torch
16 | import gradio as gr
17 |
18 | ############ Generate GIF functions ###################################################################################
19 | def generate_gif(local_directory, image_names, speed_factor=1, loop=0):
20 | """
21 |     Generate a GIF from a sequence of image files saved in a local directory.
22 |
23 |     Parameters:
24 |     - local_directory (str): Directory path where the images are located.
25 |     - image_names (list): List of filenames of the input images.
26 |     - speed_factor (int): Playback speed multiplier; higher values shorten the delay between frames.
27 |     - loop (int): Number of loops (0 for infinite loop).
28 |
29 | Returns:
30 | - Bytes of GIF
31 | """
32 | images = []
33 | durations = []
34 | for image_name in image_names:
35 | img = Image.open(os.path.join(local_directory, image_name))
36 | images.append(img)
37 | try:
38 | duration = img.info['duration']
39 | except KeyError:
40 | duration = 100 # Default duration in case 'duration' is not available
41 | durations.append(duration)
42 |
43 | # Calculate the adjusted durations based on the speed factor
44 | adjusted_durations = [int(duration / speed_factor) for duration in durations]
45 |
46 | # Save as GIF to an in-memory buffer
47 | gif_buffer = BytesIO()
48 | images[0].save(gif_buffer, format='GIF', save_all=True, append_images=images[1:], duration=adjusted_durations, loop=loop)
49 |
50 | # Get the content of the buffer as bytes
51 | gif_content = gif_buffer.getvalue()
52 | gif_content = base64.b64encode(gif_content).decode('utf-8')
53 | return gif_content
54 |
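# Usage sketch (hypothetical directory and file names). The return value is already
# base64-encoded, so it can be dropped straight into an <img> tag in the HTML report:
#
#   gif_b64 = generate_gif('./output/frames', ['f001.png', 'f002.png'], speed_factor=2)
#   html = '<img src="data:image/gif;base64,{}">'.format(gif_b64)
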
55 | def generate_gif_from_frames(frames, speed_factor=1, loop=0, progress=gr.Progress()):
56 | """
57 | Generate a GIF from a sequence of images.
58 |
59 | Parameters:
60 | - frames (list): List of cv2 frames
61 |     - speed_factor (int): Playback speed multiplier; higher values shorten the delay between frames.
62 | - loop (int): Number of loops (0 for infinite loop).
63 |
64 | Returns:
65 | - Bytes of GIF
66 | """
67 | durations = []
68 | images = []
69 | i = 0
70 | for frame in frames:
71 | image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
72 | images.append(image)
73 | duration = 100 # Default duration in case 'duration' is not available
74 | durations.append(duration)
75 | i+=1
76 |
77 | # Calculate the adjusted durations based on the speed factor
78 | adjusted_durations = [int(duration / speed_factor) for duration in durations]
79 |
80 | # Save as GIF to an in-memory buffer
81 | gif_buffer = BytesIO()
82 | images[0].save(gif_buffer, format='GIF', save_all=True, append_images=images[1:], duration=adjusted_durations, loop=loop)
83 |
84 | # Get the content of the buffer as bytes
85 | gif_content = gif_buffer.getvalue()
86 | gif_content = base64.b64encode(gif_content).decode('utf-8')
87 | return gif_content
88 |
89 | ##########################################################################################################
90 | ############ Overlay Symbols on Frames ######################################################################
91 |
92 | def draw_pose(keypoints,img, board_end_coord):
93 | """draw the keypoints and the skeletons.
94 |     :params keypoints: the shape should be equal to [16,2]
95 |     :params img:
96 |     """
97 |     NUM_KPTS = 16
98 |     assert keypoints.shape == (NUM_KPTS,2)
99 |     SKELETON = [[1,2],[1,0],[2,6],[3,6],[4,5],[3,4],[6,7],[7,8],[9,8],[7,12],[7,13],[11,12],[13,14],[14,15],[10,11]]
100 |     CocoColors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
101 |                   [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
102 |                   [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
103 | for i in range(len(SKELETON)):
104 | kpt_a, kpt_b = SKELETON[i][0], SKELETON[i][1]
105 | x_a, y_a = keypoints[kpt_a][0],keypoints[kpt_a][1]
106 | x_b, y_b = keypoints[kpt_b][0],keypoints[kpt_b][1]
107 | cv2.circle(img, (int(x_a), int(y_a)), 6, CocoColors[i], -1)
108 | cv2.circle(img, (int(x_b), int(y_b)), 6, CocoColors[i], -1)
109 | cv2.line(img, (int(x_a), int(y_a)), (int(x_b), int(y_b)), CocoColors[i], 2)
110 |
111 | def draw_symbols(opencv_image, pose_preds, board_end_coord, plat_outputs, splash_pred_mask, above_board=None):
112 | if pose_preds is not None:
113 | draw_pose(np.array(pose_preds)[0],opencv_image, board_end_coord)
114 | if above_board is None or above_board==1:
115 | draw_platform(opencv_image, plat_outputs)
116 | draw_splash(opencv_image, splash_pred_mask)
117 | return opencv_image
118 |
119 | def draw_platform(opencv_image, output):
120 | pred_classes = output['instances'].pred_classes.cpu().numpy()
121 | platforms = np.where(pred_classes == 0)[0]
122 | scores = output['instances'].scores[platforms]
123 | if len(scores) == 0:
124 | return
125 | pred_masks = output['instances'].pred_masks[platforms]
126 | max_instance = torch.argmax(scores)
127 | pred_mask = np.array(pred_masks[max_instance].cpu())
128 | # Convert the mask to a binary image
129 | binary_mask = pred_mask.squeeze().astype(np.uint8)
130 | contours, hierarchy = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
131 |
132 | cv2.drawContours(opencv_image, contours[0], -1, (36, 255, 12), thickness=5)
133 |
134 | def draw_splash(opencv_image, pred_mask):
135 | # Convert the mask to a binary image
136 | if pred_mask is None:
137 | return
138 | binary_mask = pred_mask.squeeze().astype(np.uint8)
139 | contours, hierarchy = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
140 |
141 | cv2.drawContours(opencv_image, contours[0], -1, (0, 0, 0), thickness=5)
142 |
143 | ##########################################################################################################
144 | ############ Generate HTML template ######################################################################
145 |
146 | def generate_report(template_path, data, local_directory, progress=gr.Progress()):
147 |     """Render the HTML score report from per-frame images saved in local_directory."""
148 | env = Environment(loader=FileSystemLoader('./score_report_generation/templates'))
149 |
150 | file_names = os.listdir(local_directory)
151 | file_names.sort()
152 | file_names = np.array(file_names)
153 | progress(0.9, desc="Generating Score Report")
154 |     # Collect scores, percentiles, and visual evidence for each criterion
155 | is_twister = data['twist_position_tightness']['raw_score'] is not None
156 | overall_score_desc = data['overall_score']['description']
157 | overall_score = '%.1f' % data['overall_score']['raw_score']
158 |
159 | feet_apart_score = round(float('%.2f' % data['feet_apart']['raw_score']))
160 | feet_apart_peaks = file_names[data['feet_apart']['peaks']]
161 | feet_apart_gif = None
162 | has_feet_apart_peaks = False
163 | if len(feet_apart_peaks) > 0:
164 | has_feet_apart_peaks = True
165 | feet_apart_gif = generate_gif(local_directory, feet_apart_peaks, speed_factor = 0.05)
166 | feet_apart_percentile = round(float('%.2f' % (data['feet_apart']['percentile'] *100)))
167 | feet_apart_percentile_divided_by_ten = '%.1f' % (data['feet_apart']['percentile'] *10)
168 |
169 | include_height_off_platform = False
170 | height_off_board_score = data['height_off_board']['raw_score']
171 | height_off_board_percentile = data['height_off_board']['percentile']
172 | height_off_board_description = None
173 | height_off_board_percentile_divided_by_ten = None
174 | encoded_height_off_board_frame = None
175 | if height_off_board_score is not None:
176 | include_height_off_platform = True
177 | height_off_board_score = round(float('%.2f' % data['height_off_board']['raw_score']))
178 | height_off_board_frame = file_names[data['height_off_board']['frame_index']]
179 | height_off_board_frame_path = os.path.join(local_directory, height_off_board_frame)
180 | with open(height_off_board_frame_path, "rb") as image_file:
181 | encoded_height_off_board_frame = base64.b64encode(image_file.read()).decode('utf-8')
182 | height_off_board_percentile = round(float('%.2f' % (data['height_off_board']['percentile'] *100)))
183 | height_off_board_percentile_divided_by_ten = '%.1f' % (data['height_off_board']['percentile'] *10)
184 | if float(height_off_board_percentile_divided_by_ten) > 5:
185 | height_off_board_description = "good"
186 | else:
187 | height_off_board_description = "a bit on the lower side"
188 |
189 | dist_from_board_score = '%.2f' % data['distance_from_board']['raw_score']
190 | dist_from_board_frame = file_names[data['distance_from_board']['frame_index']]
191 | dist_from_board_frame_path = os.path.join(local_directory, dist_from_board_frame)
192 | with open(dist_from_board_frame_path, "rb") as image_file:
193 | encoded_dist_from_board_frame = base64.b64encode(image_file.read()).decode('utf-8')
194 | dist_from_board_percentile = data['distance_from_board']['percentile']
195 | if 'good' in dist_from_board_percentile:
196 | dist_from_board_percentile_status = "Good"
197 | elif 'far' in dist_from_board_percentile:
198 | dist_from_board_percentile_status = "Too Far"
199 | else:
200 | dist_from_board_percentile_status = "Too Close"
201 |
202 | knee_bend_score = data['knee_bend']['raw_score']
203 | knee_bend_percentile = data['knee_bend']['percentile']
204 | knee_bend_frames = []
205 | knee_bend_percentile_divided_by_ten = None
206 | if knee_bend_score is not None:
207 | knee_bend_score = round(float('%.2f' % (data['knee_bend']['raw_score'])))
208 | knee_bend_frames = file_names[data['knee_bend']['frame_indices']]
209 | knee_bend_percentile = round(float('%.2f' % (knee_bend_percentile * 100)))
210 | knee_bend_percentile_divided_by_ten = '%.1f' % (data['knee_bend']['percentile'] * 10)
211 |
212 | som_position_tightness_score = data['som_position_tightness']['raw_score']
213 | som_position_tightness_percentile = data['som_position_tightness']['percentile']
214 | som_position_tightness_position = data['som_position_tightness']['position']
215 |     som_position_tightness_frames = []
216 |     som_position_tightness_gif = None
217 |     som_position_tightness_percentile_divided_by_ten = None
218 |     if som_position_tightness_score is not None:
219 |         # +15 adjustment for twisters, as in the original two branches; frames are
220 |         # only indexed once a score exists, mirroring generate_report_from_frames
221 |         som_position_tightness_score = round(float('%.2f' % (data['som_position_tightness']['raw_score'] + (15 if is_twister else 0))))
222 |         som_position_tightness_frames = file_names[data['som_position_tightness']['frame_indices']]
223 |         som_position_tightness_gif = generate_gif(local_directory, som_position_tightness_frames)
224 | som_position_tightness_percentile = round(float('%.2f' % (som_position_tightness_percentile * 100)))
225 | som_position_tightness_percentile_divided_by_ten = '%.1f' % (data['som_position_tightness']['percentile'] * 10)
226 | twist_position_tightness_score = data['twist_position_tightness']['raw_score']
227 | twist_position_tightness_frames = []
228 | twist_position_tightness_gif = None
229 | twist_position_tightness_percentile = None
230 | twist_position_tightness_percentile_divided_by_ten = None
231 | if twist_position_tightness_score is not None:
232 | twist_position_tightness_score = round(float('%.2f' % twist_position_tightness_score))
233 | twist_position_tightness_frames = file_names[data['twist_position_tightness']['frame_indices']]
234 | twist_position_tightness_gif = generate_gif(local_directory, twist_position_tightness_frames)
235 | twist_position_tightness_percentile = round(float('%.2f' % (data['twist_position_tightness']['percentile'] * 100)))
236 | twist_position_tightness_percentile_divided_by_ten = '%.1f' % (data['twist_position_tightness']['percentile'] * 10)
237 | over_under_rotation_score = round(float('%.2f' % data['over_under_rotation']['raw_score']))
238 | over_under_rotation_frame = file_names[data['over_under_rotation']['frame_index']]
239 | over_under_rotation_percentile = round(float('%.2f' % (data['over_under_rotation']['percentile'] * 100)))
240 | over_under_rotation_percentile_divided_by_ten = '%.1f' % (data['over_under_rotation']['percentile'] * 10)
241 | straightness_during_entry_score = round(float('%.2f' % data['straightness_during_entry']['raw_score']))
242 | straightness_during_entry_frames = file_names[data['straightness_during_entry']['frame_indices']]
243 | straightness_during_entry_gif = generate_gif(local_directory, straightness_during_entry_frames, speed_factor = 0.5)
244 | straightness_during_entry_percentile = round(float('%.2f' % (data['straightness_during_entry']['percentile'] * 100)))
245 | straightness_during_entry_percentile_divided_by_ten = '%.1f' % (data['straightness_during_entry']['percentile'] * 10)
246 | splash_score = round(float('%.2f' % data['splash']['raw_score']))
247 | splash_frame = file_names[data['splash']['maximum_index']]
248 | splash_indices = file_names[data['splash']['frame_indices']]
249 | splash_gif = None
250 | if len(splash_indices) > 0:
251 | splash_gif = generate_gif(local_directory, splash_indices)
252 | splash_percentile = round(float('%.2f' % (data['splash']['percentile'] * 100)))
253 | splash_percentile_divided_by_ten = '%.1f' % (data['splash']['percentile'] * 10)
254 | if float(splash_percentile) < 50:
255 | splash_description = 'on the larger side'
256 | else:
257 | splash_description = 'small'
258 | template = env.get_template(template_path)
259 | data = {
260 | 'local_directory': local_directory,
261 | 'is_twister' : is_twister,
262 | 'overall_score_desc' : overall_score_desc,
263 | 'overall_score' : overall_score,
264 | 'feet_apart_score' : feet_apart_score,
265 | 'feet_apart_peaks' : feet_apart_peaks,
266 | 'has_feet_apart_peaks' : has_feet_apart_peaks,
267 | 'feet_apart_gif' : feet_apart_gif,
268 | 'feet_apart_percentile' : feet_apart_percentile,
269 | 'feet_apart_percentile_divided_by_ten': feet_apart_percentile_divided_by_ten,
270 |
271 | 'include_height_off_platform': include_height_off_platform,
272 | 'height_off_board_score' : height_off_board_score,
273 | 'height_off_board_percentile' : height_off_board_percentile,
274 | 'encoded_height_off_board_frame' : encoded_height_off_board_frame,
275 | 'height_off_board_percentile_divided_by_ten': height_off_board_percentile_divided_by_ten,
276 | 'height_off_board_description' : height_off_board_description,
277 |
278 | 'dist_from_board_score' : dist_from_board_score,
279 | 'dist_from_board_frame' : dist_from_board_frame,
280 | 'encoded_dist_from_board_frame': encoded_dist_from_board_frame,
281 | 'dist_from_board_percentile' : dist_from_board_percentile,
282 | 'dist_from_board_percentile_status': dist_from_board_percentile_status,
283 |
284 | 'knee_bend_score' : knee_bend_score,
285 | 'knee_bend_frames' : knee_bend_frames,
286 | 'knee_bend_percentile' : knee_bend_percentile,
287 | 'knee_bend_percentile_divided_by_ten' : knee_bend_percentile_divided_by_ten,
288 |
289 | 'som_position_tightness_score' : som_position_tightness_score,
290 | 'som_position_tightness_frames' : som_position_tightness_frames,
291 | 'som_position_tightness_gif' : som_position_tightness_gif,
292 | 'som_position_tightness_position' : som_position_tightness_position,
293 | 'som_position_tightness_percentile' : som_position_tightness_percentile,
294 | 'som_position_tightness_percentile_divided_by_ten' : som_position_tightness_percentile_divided_by_ten,
295 | 'twist_position_tightness_score' : twist_position_tightness_score,
296 | 'twist_position_tightness_frames': twist_position_tightness_frames,
297 | 'twist_position_tightness_percentile' : twist_position_tightness_percentile,
298 | 'twist_position_tightness_percentile_divided_by_ten' : twist_position_tightness_percentile_divided_by_ten,
299 | 'twist_position_tightness_gif' : twist_position_tightness_gif,
300 | 'over_under_rotation_score' : over_under_rotation_score,
301 | 'over_under_rotation_frame' : over_under_rotation_frame,
302 | 'over_under_rotation_percentile' : over_under_rotation_percentile,
303 | 'over_under_rotation_percentile_divided_by_ten' : over_under_rotation_percentile_divided_by_ten,
304 | 'straightness_during_entry_score' : straightness_during_entry_score,
305 | 'straightness_during_entry_gif' : straightness_during_entry_gif,
306 | 'straightness_during_entry_percentile' : straightness_during_entry_percentile,
307 | 'straightness_during_entry_percentile_divided_by_ten': straightness_during_entry_percentile_divided_by_ten,
308 | 'splash_score' : splash_score,
309 | 'splash_frame' : splash_frame,
310 | 'splash_gif' : splash_gif,
311 | 'splash_percentile' : splash_percentile,
312 | 'splash_percentile_divided_by_ten': splash_percentile_divided_by_ten,
313 | 'splash_description' : splash_description,
314 | }
315 | # Render the template with the provided data
316 | report_content = template.render(data)
317 | return report_content
318 |
319 |
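# --- Hedged usage sketch (illustrative, not part of the original pipeline) ---
# generate_report() returns a rendered HTML string; persisting it is up to the
# caller. The template filename matches the one shipped in
# ./score_report_generation/templates, but the output path and the no-op
# progress callable are assumptions:
def _example_write_report(data, local_directory):
    html = generate_report('report_template_tables.html', data, local_directory,
                           progress=lambda *args, **kwargs: None)
    with open('score_report.html', 'w') as f:
        f.write(html)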
320 | def generate_report_from_frames(template_path, data, frames):
321 |     """Render the HTML score report directly from in-memory BGR frames."""
322 | env = Environment(loader=FileSystemLoader('./score_report_generation/templates'))
323 |
324 | frames = np.array(frames)
325 |     # Collect scores, percentiles, and visual evidence for each criterion
326 | is_twister = data['twist_position_tightness']['raw_score'] is not None
327 | overall_score_desc = data['overall_score']['description']
328 | overall_score = '%.1f' % data['overall_score']['raw_score']
329 |
330 | feet_apart_score = round(float('%.2f' % data['feet_apart']['raw_score']))
331 | feet_apart_peaks = frames[data['feet_apart']['peaks']]
332 | feet_apart_gif = None
333 | has_feet_apart_peaks = False
334 | if len(feet_apart_peaks) > 0:
335 | has_feet_apart_peaks = True
336 | feet_apart_gif = generate_gif_from_frames(feet_apart_peaks, speed_factor = 0.05)
337 | feet_apart_percentile = round(float('%.2f' % (data['feet_apart']['percentile'] *100)))
338 | feet_apart_percentile_divided_by_ten = '%.1f' % (data['feet_apart']['percentile'] *10)
339 |
340 | include_height_off_platform = False
341 | height_off_board_score = data['height_off_board']['raw_score']
342 | height_off_board_frame = None
343 | height_off_board_percentile = None
344 | height_off_board_percentile_divided_by_ten = None
345 | height_off_board_description = None
346 | encoded_height_off_board_frame = None
347 | if height_off_board_score is not None:
348 | include_height_off_platform = True
349 | height_off_board_score = round(float('%.2f' % data['height_off_board']['raw_score']))
350 | height_off_board_frame = Image.fromarray(cv2.cvtColor(frames[data['height_off_board']['frame_index']], cv2.COLOR_BGR2RGB))
351 | height_buffer = BytesIO()
352 | height_off_board_frame.save(height_buffer, format='JPEG')
353 | encoded_height_off_board_frame = base64.b64encode(height_buffer.getvalue()).decode('utf-8')
354 | height_off_board_percentile = round(float('%.2f' % (data['height_off_board']['percentile'] *100)))
355 | height_off_board_percentile_divided_by_ten = '%.1f' % (data['height_off_board']['percentile'] *10)
356 | if float(height_off_board_percentile_divided_by_ten) > 5:
357 | height_off_board_description = "good"
358 | else:
359 | height_off_board_description = "a bit on the lower side"
360 |
361 | dist_from_board_score = '%.2f' % data['distance_from_board']['raw_score']
362 | dist_from_board_frame = Image.fromarray(cv2.cvtColor(frames[data['distance_from_board']['frame_index']], cv2.COLOR_BGR2RGB))
363 | dist_buffer = BytesIO()
364 | dist_from_board_frame.save(dist_buffer, format='JPEG')
365 | encoded_dist_from_board_frame = base64.b64encode(dist_buffer.getvalue()).decode('utf-8')
366 | dist_from_board_percentile = data['distance_from_board']['percentile']
367 | if 'good' in dist_from_board_percentile:
368 | dist_from_board_percentile_status = "Good"
369 | elif 'far' in dist_from_board_percentile:
370 | dist_from_board_percentile_status = "Too Far"
371 | else:
372 | dist_from_board_percentile_status = "Too Close"
373 |
374 | knee_bend_score = data['knee_bend']['raw_score']
375 | knee_bend_percentile = data['knee_bend']['percentile']
376 | knee_bend_frames = []
377 | knee_bend_percentile_divided_by_ten = None
378 | if knee_bend_score is not None:
379 | knee_bend_score = round(float('%.2f' % knee_bend_score))
380 | knee_bend_percentile = round(float('%.2f' % (knee_bend_percentile * 100)))
381 | knee_bend_frames = frames[data['knee_bend']['frame_indices']]
382 | knee_bend_percentile_divided_by_ten = '%.1f' % (data['knee_bend']['percentile'] * 10)
383 |
384 | som_position_tightness_score = data['som_position_tightness']['raw_score']
385 | som_position_tightness_percentile = data['som_position_tightness']['percentile']
386 | som_position_tightness_position = data['som_position_tightness']['position']
387 | som_position_tightness_frames = []
388 | som_position_tightness_gif = None
389 | som_position_tightness_percentile_divided_by_ten = None
390 | if som_position_tightness_score is not None:
391 | if is_twister:
392 | som_position_tightness_score = round(float('%.2f' % (data['som_position_tightness']['raw_score'] + 15)))
393 | else:
394 | som_position_tightness_score = round(float('%.2f' % (data['som_position_tightness']['raw_score'])))
395 | som_position_tightness_frames = frames[data['som_position_tightness']['frame_indices']]
396 | som_position_tightness_gif = generate_gif_from_frames(som_position_tightness_frames)
397 | som_position_tightness_percentile = round(float('%.2f' % (som_position_tightness_percentile * 100)))
398 | som_position_tightness_percentile_divided_by_ten = '%.1f' % (data['som_position_tightness']['percentile'] * 10)
399 |
400 | twist_position_tightness_score = data['twist_position_tightness']['raw_score']
401 | twist_position_tightness_frames = []
402 | twist_position_tightness_gif = None
403 | twist_position_tightness_percentile = None
404 | twist_position_tightness_percentile_divided_by_ten = None
405 | if twist_position_tightness_score is not None:
406 | twist_position_tightness_score = round(float('%.2f' % twist_position_tightness_score))
407 | twist_position_tightness_frames = frames[data['twist_position_tightness']['frame_indices']]
408 | twist_position_tightness_gif = generate_gif_from_frames(twist_position_tightness_frames)
409 | twist_position_tightness_percentile = round(float('%.2f' % (data['twist_position_tightness']['percentile'] * 100)))
410 | twist_position_tightness_percentile_divided_by_ten = '%.1f' % (data['twist_position_tightness']['percentile'] * 10)
411 |
412 | over_under_rotation_score = round(float('%.2f' % data['over_under_rotation']['raw_score']))
413 | over_under_rotation_frame = frames[data['over_under_rotation']['frame_index']]
414 | over_under_rotation_percentile = round(float('%.2f' % (data['over_under_rotation']['percentile'] * 100)))
415 | over_under_rotation_percentile_divided_by_ten = '%.1f' % (data['over_under_rotation']['percentile'] * 10)
416 |
417 | straightness_during_entry_score = round(float('%.2f' % data['straightness_during_entry']['raw_score']))
418 | straightness_during_entry_frames = frames[data['straightness_during_entry']['frame_indices']]
419 | straightness_during_entry_gif = generate_gif_from_frames(straightness_during_entry_frames, speed_factor = 0.5)
420 | straightness_during_entry_percentile = round(float('%.2f' % (data['straightness_during_entry']['percentile'] * 100)))
421 | straightness_during_entry_percentile_divided_by_ten = '%.1f' % (data['straightness_during_entry']['percentile'] * 10)
422 |
423 | splash_score = round(float('%.2f' % data['splash']['raw_score']))
424 | splash_frame = frames[data['splash']['maximum_index']]
425 | splash_indices = frames[data['splash']['frame_indices']]
426 | splash_gif = None
427 | if len(splash_indices) > 0:
428 | splash_gif = generate_gif_from_frames(splash_indices)
429 | splash_percentile = round(float('%.2f' % (data['splash']['percentile'] * 100)))
430 | splash_percentile_divided_by_ten = '%.1f' % (data['splash']['percentile'] * 10)
431 | if float(splash_percentile) < 50:
432 | splash_description = 'on the larger side'
433 | else:
434 | splash_description = 'small'
435 | template = env.get_template(template_path)
436 | data = {
437 | 'is_twister' : is_twister,
438 | 'overall_score_desc' : overall_score_desc,
439 | 'overall_score' : overall_score,
440 | 'feet_apart_score' : feet_apart_score,
441 | 'feet_apart_peaks' : feet_apart_peaks,
442 | 'has_feet_apart_peaks' : has_feet_apart_peaks,
443 | 'feet_apart_gif' : feet_apart_gif,
444 | 'feet_apart_percentile' : feet_apart_percentile,
445 | 'feet_apart_percentile_divided_by_ten': feet_apart_percentile_divided_by_ten,
446 | 'include_height_off_platform': include_height_off_platform,
447 | 'height_off_board_score' : height_off_board_score,
448 | 'height_off_board_percentile' : height_off_board_percentile,
449 | 'encoded_height_off_board_frame' : encoded_height_off_board_frame,
450 | 'height_off_board_percentile_divided_by_ten': height_off_board_percentile_divided_by_ten,
451 | 'height_off_board_description' : height_off_board_description,
452 | 'dist_from_board_score' : dist_from_board_score,
453 | 'dist_from_board_frame' : dist_from_board_frame,
454 | 'encoded_dist_from_board_frame': encoded_dist_from_board_frame,
455 | 'dist_from_board_percentile' : dist_from_board_percentile,
456 | 'dist_from_board_percentile_status': dist_from_board_percentile_status,
457 | 'knee_bend_score' : knee_bend_score,
458 | 'knee_bend_frames' : knee_bend_frames,
459 | 'knee_bend_percentile' : knee_bend_percentile,
460 | 'knee_bend_percentile_divided_by_ten' : knee_bend_percentile_divided_by_ten,
461 | 'som_position_tightness_score' : som_position_tightness_score,
462 | 'som_position_tightness_frames' : som_position_tightness_frames,
463 | 'som_position_tightness_gif' : som_position_tightness_gif,
464 | 'som_position_tightness_position' : som_position_tightness_position,
465 | 'som_position_tightness_percentile' : som_position_tightness_percentile,
466 | 'som_position_tightness_percentile_divided_by_ten' : som_position_tightness_percentile_divided_by_ten,
467 | 'twist_position_tightness_score' : twist_position_tightness_score,
468 | 'twist_position_tightness_frames': twist_position_tightness_frames,
469 | 'twist_position_tightness_percentile' : twist_position_tightness_percentile,
470 | 'twist_position_tightness_percentile_divided_by_ten' : twist_position_tightness_percentile_divided_by_ten,
471 | 'twist_position_tightness_gif' : twist_position_tightness_gif,
472 | 'over_under_rotation_score' : over_under_rotation_score,
473 | 'over_under_rotation_frame' : over_under_rotation_frame,
474 | 'over_under_rotation_percentile' : over_under_rotation_percentile,
475 | 'over_under_rotation_percentile_divided_by_ten' : over_under_rotation_percentile_divided_by_ten,
476 | 'straightness_during_entry_score' : straightness_during_entry_score,
477 | 'straightness_during_entry_gif' : straightness_during_entry_gif,
478 | 'straightness_during_entry_percentile' : straightness_during_entry_percentile,
479 | 'straightness_during_entry_percentile_divided_by_ten': straightness_during_entry_percentile_divided_by_ten,
480 | 'splash_score' : splash_score,
481 | 'splash_frame' : splash_frame,
482 | 'splash_gif' : splash_gif,
483 | 'splash_percentile' : splash_percentile,
484 | 'splash_percentile_divided_by_ten': splash_percentile_divided_by_ten,
485 | 'splash_description' : splash_description,
486 | }
487 | # Render the template with the provided data
488 | report_content = template.render(data)
489 | return report_content
490 |
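# --- Hedged usage sketch (illustrative, not part of the original pipeline) ---
# generate_report_from_frames() consumes in-memory BGR frames rather than a
# directory of saved .jpg files. One way to collect such frames from a video
# (the video path is an assumption):
def _example_frames_from_video(video_path='dive.mp4'):
    cap = cv2.VideoCapture(video_path)
    frames = []
    while True:
        ok, frame = cap.read()  # frame is BGR, matching what the drawing helpers expect
        if not ok:
            break
        frames.append(frame)
    cap.release()
    return frames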
491 | def generate_symbols_report(template_path, dive_data, frames):
492 |     """Render an HTML page with pose, platform, and splash symbols overlaid on each frame."""
493 |     env = Environment(loader=FileSystemLoader('./score_report_generation/templates'))
494 | template = env.get_template(template_path)
495 | pose_frames = []
496 | for i in range(len(dive_data['pose_pred'])):
497 | pose_frame = draw_symbols(frames[i], dive_data['pose_pred'][i], dive_data['board_end_coords'][i], dive_data['plat_outputs'][i], dive_data['splash_pred_masks'][i])
498 | pose_frames.append(pose_frame)
499 | pose_gif = generate_gif_from_frames(pose_frames, speed_factor=2)
500 | pose_data = {}
501 | pose_data['pose_gif'] = pose_gif
502 | html = template.render(pose_data)
503 | return html
504 |
505 | def generate_symbols_report_precomputed(template_path, dive_data, local_directory, progress=gr.Progress()):
506 |     """Render the symbols-overlay report from frames already saved in local_directory."""
507 | file_names = os.listdir(local_directory)
508 | file_names.sort()
509 | file_names = np.array(file_names)
510 |
511 | if 'above_boards' in dive_data:
512 | above_boards = dive_data['above_boards']
513 | else:
514 | above_boards = [None] * len(file_names)
515 |
516 | env = Environment(loader=FileSystemLoader('./score_report_generation/templates'))
517 | template = env.get_template(template_path)
518 | pose_frames = []
519 | counter = 0
520 | for i in range(len(file_names)):
521 | progress(i/(len(file_names)+10), desc="Abstracting Symbols")
522 |         if not file_names[i].endswith(".jpg"):
523 |             continue
524 |         opencv_image = cv2.imread(os.path.join(local_directory, file_names[i]))  # join safely even without a trailing slash
525 |         pose_frame = draw_symbols(opencv_image, dive_data['pose_pred'][counter], dive_data['board_end_coords'][counter], dive_data['plat_outputs'][counter], dive_data['splash_pred_masks'][counter], above_board=above_boards[counter])
526 |         pose_frames.append(pose_frame)
527 |         counter += 1
528 | pose_gif = generate_gif_from_frames(pose_frames, speed_factor=2, progress=progress)
529 | pose_data = {}
530 | pose_data['pose_gif'] = pose_gif
531 | html = template.render(pose_data)
532 | return html
533 |
--------------------------------------------------------------------------------