├── models ├── pose_estimator │ ├── cfg │ │ ├── readme.txt │ │ └── w32_256x256_adam_lr1e-3.yaml │ ├── readme.txt │ ├── model_weights │ │ └── readme.txt │ └── pose_estimator_model_setup.py └── detectron2 │ ├── readme.txt │ ├── model_weights │ └── readme.txt │ └── detectors.py ├── rule_based_programs ├── microprograms │ ├── readme.txt │ ├── temporal_segmentation_functions.py │ ├── dive_error_functions.py │ └── dive_recognition_functions.py ├── README.md ├── aqa_metaProgram_finediving.py ├── aqa_metaProgram.py └── scoring_functions.py ├── score_report_generation ├── templates │ ├── readme.txt │ └── report_template_tables.html └── generate_report_functions.py ├── teaser_fig.png ├── README.md ├── nsaqa.py └── nsaqa_finediving.py /models/pose_estimator/cfg/readme.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /rule_based_programs/microprograms/readme.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /score_report_generation/templates/readme.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /teaser_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/laurenok24/NSAQA/HEAD/teaser_fig.png -------------------------------------------------------------------------------- /models/detectron2/readme.txt: -------------------------------------------------------------------------------- 1 | You will need to install the detectron2 model: https://github.com/facebookresearch/detectron2 2 | -------------------------------------------------------------------------------- /models/pose_estimator/readme.txt: -------------------------------------------------------------------------------- 1 | You will need to install the HRNet model: https://github.com/leoxiaobin/deep-high-resolution-net.pytorch 2 | 3 | Additionally, you will need to add pose_hrnet.py (found at https://github.com/leoxiaobin/deep-high-resolution-net.pytorch/blob/master/lib/models/pose_hrnet.py) to this folder. 
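
A quick way to confirm the setup before running the full pipeline is an import check. The snippet below is only a convenience sketch, not part of the pipeline; it assumes the layout the rest of this repository expects (the HRNet repo cloned into the repository root, and pose_hrnet.py copied into this folder), and should be run from the repository root.

```python
# Optional sanity check for the HRNet setup (run from the repository root).
# Paths mirror those used by models/pose_estimator/pose_estimator_model_setup.py.
import os, sys

assert os.path.isdir('./deep-high-resolution-net.pytorch/lib'), \
    "clone deep-high-resolution-net.pytorch into the repository root"
assert os.path.exists('./models/pose_estimator/pose_hrnet.py'), \
    "copy pose_hrnet.py from the HRNet repo into models/pose_estimator/"

sys.path.append('./deep-high-resolution-net.pytorch/lib')
from config import cfg  # HRNet experiment config package under lib/
from models.pose_estimator.pose_hrnet import get_pose_net
print("HRNet setup looks importable")
```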
4 | -------------------------------------------------------------------------------- /models/pose_estimator/model_weights/readme.txt: -------------------------------------------------------------------------------- 1 | Download model weights for HRNet: https://drive.google.com/drive/folders/1sdiGj9lnhNi0-Nix-wSSQeMgYHPMSFCQ?usp=sharing 2 | 3 | If you wish to train the models on your own, here are the annotations we hand annotated for training: https://drive.google.com/drive/folders/1_lKxYwdvWRqxt3YkV4sABbbY7fWhaQSA?usp=sharing 4 | -------------------------------------------------------------------------------- /models/detectron2/model_weights/readme.txt: -------------------------------------------------------------------------------- 1 | Download model weights for platform, splash, and diver detection: 2 | https://drive.google.com/drive/folders/1sdiGj9lnhNi0-Nix-wSSQeMgYHPMSFCQ?usp=sharing 3 | 4 | If you wish to train the models on your own, here are the annotations we hand annotated for training: https://drive.google.com/drive/folders/1_lKxYwdvWRqxt3YkV4sABbbY7fWhaQSA?usp=sharing 5 | -------------------------------------------------------------------------------- /rule_based_programs/README.md: -------------------------------------------------------------------------------- 1 | Run all code from the root. 2 | 3 | aqa_metaProgram.py: 4 | Extract information from a dive clip necessary for scoring. Extracted data will be saved as a pickle file. 5 | ``` 6 | python rule_based_programs/aqa_metaProgram.py path/to/video.mp4 7 | ``` 8 | 9 | aqa_metaProgram_finediving.py: 10 | Extract information from dive frames in the FineDiving dataset found at: https://github.com/xujinglin/FineDiving/tree/main. 11 | Each dive in this dataset has an instance id (x, y). 12 | ``` 13 | python rule_based_programs/aqa_metaProgram_finediving.py x y 14 | 15 | # e.g. for instance id ('01', 1) 16 | python rule_based_programs/aqa_metaProgram_finediving.py 01 1 17 | ``` 18 | 19 | scoring_functions.py: 20 | You'll need to download [distribution_data.pkl](https://drive.google.com/file/d/1Dbc7LVeuo3d8RVtx0fWdjgoqjbNCtnuk/view?usp=sharing) in order to generate scores. This file contains the distribution of data (from the dives in the FineDiving dataset) used to calculate the percentile scores. 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ⚠️ ***Please note that all materials (codes, dataset, etc.) are available for non-commercial use only.*** ⚠️ Please refrain from using the materials for any commercial purposes in any form or capacity. 2 | 3 | # Neuro-Symbolic Action Quality Assessment (NS-AQA) 4 | ## 🏆 CVPR 2024 CVSports Best Paper Award 5 | This repository contains the Python code implementation of NS-AQA for platform diving. 6 | 7 | 📝 **Technical Paper:** [link](https://arxiv.org/abs/2403.13798) 8 | 9 | 🤗 **Huggingface Demo:** [link](https://huggingface.co/spaces/X-NS/NSAQA) 10 | 11 | ## Overview 12 | We propose a neuro-symbolic paradigm for AQA. 13 | ![NS-AQA Concept](teaser_fig.png) 14 | 15 | ## Run NS-AQA 16 | Score Report is saved as an HTML file at "./output/" 17 | 18 | Run NS-AQA on a single dive clip. 19 | ``` 20 | python nsaqa.py path/to/video.mp4 21 | ``` 22 | Run NS-AQA on a single dive from the [FineDiving Dataset](https://github.com/xujinglin/FineDiving). Each dive in the dataset has an instance id (x, y). 23 | ``` 24 | python nsaqa_finediving.py x y 25 | 26 | # e.g. 
if the instance id is ('01', 1)
27 | python nsaqa_finediving.py 01 1
28 | ```
29 | 
30 | ## Please consider citing our work:
31 | ```
32 | @inproceedings{okamoto2024hierarchical,
33 |   title={Hierarchical NeuroSymbolic Approach for Comprehensive and Explainable Action Quality Assessment},
34 |   author={Okamoto, Lauren and Parmar, Paritosh},
35 |   booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
36 |   pages={3204--3213},
37 |   year={2024}
38 | }
39 | ```
40 | 
-------------------------------------------------------------------------------- /nsaqa.py: --------------------------------------------------------------------------------
1 | """
2 | nsaqa.py
3 | Author: Lauren Okamoto
4 | """
5 | 
6 | import os
7 | import pickle
8 | from models.detectron2.detectors import get_platform_detector, get_diver_detector, get_splash_detector
9 | from models.pose_estimator.pose_estimator_model_setup import get_pose_estimation, get_pose_model
10 | from rule_based_programs.scoring_functions import *
11 | from score_report_generation.generate_report_functions import *
12 | from rule_based_programs.aqa_metaProgram import aqa_metaprogram, abstractSymbols, extract_frames
13 | import argparse
14 | 
15 | def main(video_path):
16 |     platform_detector = get_platform_detector()
17 |     splash_detector = get_splash_detector()
18 |     diver_detector = get_diver_detector()
19 |     pose_model = get_pose_model()
20 |     template_path = 'report_template_tables.html'
21 |     dive_data = {}
22 | 
23 |     frames = extract_frames(video_path)
24 |     dive_data = abstractSymbols(frames, platform_detector=platform_detector, splash_detector=splash_detector, diver_detector=diver_detector, pose_model=pose_model)
25 |     dive_data = aqa_metaprogram(frames, dive_data, platform_detector=platform_detector, splash_detector=splash_detector, diver_detector=diver_detector, pose_model=pose_model)
26 |     intermediate_scores = get_all_report_scores(dive_data)
27 |     html = generate_report_from_frames(template_path, intermediate_scores, frames)
28 |     # name the report after the input video and make sure the output directory exists
29 |     video_name = os.path.splitext(os.path.basename(video_path))[0]
30 |     save_path = "./output/{}_report.html".format(video_name)
31 |     os.makedirs("./output", exist_ok=True)
32 |     with open(save_path, 'w') as f:
33 |         print("saving html report into " + save_path)
34 |         f.write(html)
35 | 
36 | 
37 | if __name__ == '__main__':
38 |     # Set up command-line arguments
39 |     new_parser = argparse.ArgumentParser(description="Run NS-AQA on a dive video and save the HTML score report.")
40 |     new_parser.add_argument("video_path", type=str, help="Path to dive video (mp4 format).")
41 |     meta_program_args = new_parser.parse_args()
42 |     video_path = meta_program_args.video_path
43 | 
44 |     main(video_path)
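
As noted in rule_based_programs/README.md, scoring relies on distribution_data.pkl, the distribution of raw measurements over FineDiving dives. The actual conversion to percentile scores lives in rule_based_programs/scoring_functions.py; the sketch below only illustrates the idea, and the per-error key names and array layout it assumes are hypothetical.

```python
# Illustrative sketch only: percentile scoring against the FineDiving
# reference distribution. The key names and layout of distribution_data.pkl
# are assumptions for illustration; see scoring_functions.py for the real
# implementation.
import pickle
import numpy as np

with open('distribution_data.pkl', 'rb') as f:
    distribution_data = pickle.load(f)

def percentile_score(error_name, raw_value, lower_is_better=True):
    """Percentile of this dive's raw measurement among the reference dives."""
    values = np.asarray(distribution_data[error_name], dtype=float)
    # percentile = share of reference dives this dive outperforms
    beaten = values >= raw_value if lower_is_better else values <= raw_value
    return 100.0 * beaten.mean()

# e.g. a 9 degree average leg separation, where smaller angles are better:
# print(percentile_score('feet_apart', 9.0))
```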
-------------------------------------------------------------------------------- /nsaqa_finediving.py: --------------------------------------------------------------------------------
1 | """
2 | nsaqa_finediving.py
3 | Author: Lauren Okamoto
4 | """
5 | 
6 | import os
7 | import pickle
8 | from models.detectron2.detectors import get_platform_detector, get_diver_detector, get_splash_detector
9 | from models.pose_estimator.pose_estimator_model_setup import get_pose_estimation, get_pose_model
10 | from rule_based_programs.scoring_functions import *
11 | from score_report_generation.generate_report_functions import *
12 | from rule_based_programs.aqa_metaProgram_finediving import aqa_metaprogram_finediving
13 | import argparse
14 | 
15 | def main(key):
16 |     platform_detector = get_platform_detector()
17 |     splash_detector = get_splash_detector()
18 |     diver_detector = get_diver_detector()
19 |     pose_model = get_pose_model()
20 | 
21 |     # Fine-grained annotations from FineDiving Dataset
22 |     with open('FineDiving/Annotations/fine-grained_annotation_aqa.pkl', 'rb') as f:
23 |         dive_annotation_data = pickle.load(f)
24 |     diveNum = dive_annotation_data[key][0]
25 |     template_path = 'report_template_tables.html'
26 |     dive_data = {}
27 | 
28 |     dive_data = aqa_metaprogram_finediving(key[0], key[1], diveNum, platform_detector=platform_detector, splash_detector=splash_detector, diver_detector=diver_detector, pose_model=pose_model)
29 |     intermediate_scores = get_all_report_scores(dive_data)
30 | 
31 |     local_directory = "FineDiving/datasets/FINADiving_MTL_256s/{}/{}/".format(key[0], key[1])
32 |     html = generate_report(template_path, intermediate_scores, local_directory)
33 |     # the report is HTML, so save it with an .html extension and make sure ./output exists
34 |     save_path = "./output/{}_{}_report.html".format(key[0], key[1])
35 |     os.makedirs("./output", exist_ok=True)
36 |     with open(save_path, 'w') as f:
37 |         print("saving html report into " + save_path)
38 |         f.write(html)
39 | 
40 | 
41 | if __name__ == '__main__':
42 |     # Set up command-line arguments
43 |     new_parser = argparse.ArgumentParser(description="Run NS-AQA on a single FineDiving dive and save the HTML score report.")
44 |     new_parser.add_argument("FineDiving_key", type=str, nargs=2, help="key from FineDiving Dataset (e.g. 01 1)")
45 |     meta_program_args = new_parser.parse_args()
46 |     key = tuple(meta_program_args.FineDiving_key)
47 |     key = (key[0], int(key[1]))
48 |     print(key)
49 | 
50 |     main(key)
-------------------------------------------------------------------------------- /models/detectron2/detectors.py: --------------------------------------------------------------------------------
1 | """
2 | detectors.py
3 | Author: Lauren Okamoto
4 | 
5 | Code used to initialize the object detector models to be used for inference.
6 | """
7 | 
8 | import sys, os, distutils.core
9 | sys.path.insert(0, os.path.abspath('./detectron2'))
10 | 
11 | import detectron2
12 | import cv2
13 | 
14 | from detectron2.utils.logger import setup_logger
15 | setup_logger()
16 | 
17 | from detectron2 import model_zoo
18 | from detectron2.engine import DefaultPredictor
19 | from detectron2.config import get_cfg
20 | from detectron2.utils.visualizer import Visualizer
21 | from detectron2.data import MetadataCatalog, DatasetCatalog
22 | from detectron2.checkpoint import DetectionCheckpointer
23 | from detectron2.data.datasets import register_coco_instances
24 | 
25 | 
26 | def get_diver_detector():
27 |     cfg = get_cfg()
28 |     cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
29 |     cfg.OUTPUT_DIR = "./models/detectron2/model_weights/"
30 |     cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "diver_model_final.pth")
31 |     cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7
32 |     cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
33 |     diver_detector = DefaultPredictor(cfg)
34 |     return diver_detector
35 | 
36 | 
37 | def get_platform_detector():
38 |     cfg = get_cfg()
39 |     cfg.OUTPUT_DIR = "./models/detectron2/model_weights/"
40 |     cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
41 |     cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2
42 |     cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "plat_model_final.pth")
43 |     cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7
44 |     platform_detector = DefaultPredictor(cfg)
45 |     return platform_detector
46 | 
47 | def get_splash_detector():
48 |     cfg = get_cfg()
49 |     cfg.OUTPUT_DIR = "./models/detectron2/model_weights/"
50 |     cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
51 |     cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2
52 |     cfg.MODEL.WEIGHTS =
os.path.join(cfg.OUTPUT_DIR, "splash_model_final.pth") 53 | cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7 54 | splash_detector = DefaultPredictor(cfg) 55 | return splash_detector -------------------------------------------------------------------------------- /models/pose_estimator/cfg/w32_256x256_adam_lr1e-3.yaml: -------------------------------------------------------------------------------- 1 | AUTO_RESUME: true 2 | CUDNN: 3 | BENCHMARK: true 4 | DETERMINISTIC: false 5 | ENABLED: true 6 | DATA_DIR: '' 7 | GPUS: (0,) 8 | OUTPUT_DIR: 'output' 9 | LOG_DIR: 'log' 10 | WORKERS: 8 11 | PRINT_FREQ: 100 12 | 13 | DATASET: 14 | COLOR_RGB: true 15 | DATASET: mpii 16 | DATA_FORMAT: jpg 17 | FLIP: true 18 | NUM_JOINTS_HALF_BODY: 8 19 | PROB_HALF_BODY: -1.0 20 | ROOT: 'deep-high-resolution-net.pytorch/data/mpii/' 21 | ROT_FACTOR: 30 22 | SCALE_FACTOR: 0.25 23 | TEST_SET: valid 24 | TRAIN_SET: train_all 25 | MODEL: 26 | INIT_WEIGHTS: true 27 | NAME: pose_hrnet 28 | NUM_JOINTS: 16 29 | PRETRAINED: 'deep-high-resolution-net.pytorch/models/pytorch/pose_mpii/pose_hrnet_w32_256x256.pth' 30 | TARGET_TYPE: gaussian 31 | IMAGE_SIZE: 32 | - 256 33 | - 256 34 | HEATMAP_SIZE: 35 | - 64 36 | - 64 37 | SIGMA: 2 38 | EXTRA: 39 | PRETRAINED_LAYERS: 40 | - 'conv1' 41 | - 'bn1' 42 | - 'conv2' 43 | - 'bn2' 44 | - 'layer1' 45 | - 'transition1' 46 | - 'stage2' 47 | - 'transition2' 48 | - 'stage3' 49 | - 'transition3' 50 | - 'stage4' 51 | FINAL_CONV_KERNEL: 1 52 | STAGE2: 53 | NUM_MODULES: 1 54 | NUM_BRANCHES: 2 55 | BLOCK: BASIC 56 | NUM_BLOCKS: 57 | - 4 58 | - 4 59 | NUM_CHANNELS: 60 | - 32 61 | - 64 62 | FUSE_METHOD: SUM 63 | STAGE3: 64 | NUM_MODULES: 4 65 | NUM_BRANCHES: 3 66 | BLOCK: BASIC 67 | NUM_BLOCKS: 68 | - 4 69 | - 4 70 | - 4 71 | NUM_CHANNELS: 72 | - 32 73 | - 64 74 | - 128 75 | FUSE_METHOD: SUM 76 | STAGE4: 77 | NUM_MODULES: 3 78 | NUM_BRANCHES: 4 79 | BLOCK: BASIC 80 | NUM_BLOCKS: 81 | - 4 82 | - 4 83 | - 4 84 | - 4 85 | NUM_CHANNELS: 86 | - 32 87 | - 64 88 | - 128 89 | - 256 90 | FUSE_METHOD: SUM 91 | LOSS: 92 | USE_TARGET_WEIGHT: true 93 | TRAIN: 94 | BATCH_SIZE_PER_GPU: 64 95 | SHUFFLE: true 96 | BEGIN_EPOCH: 0 97 | END_EPOCH: 210 98 | OPTIMIZER: adam 99 | LR: 0.001 100 | LR_FACTOR: 0.1 101 | LR_STEP: 102 | - 170 103 | - 200 104 | WD: 0.0001 105 | GAMMA1: 0.99 106 | GAMMA2: 0.0 107 | MOMENTUM: 0.9 108 | NESTEROV: false 109 | TEST: 110 | BATCH_SIZE_PER_GPU: 64 111 | MODEL_FILE: './models/pose_estimator/model_weights/diver_pose_model_best.pth' 112 | FLIP_TEST: true 113 | POST_PROCESS: true 114 | SHIFT_HEATMAP: true 115 | DEBUG: 116 | DEBUG: true 117 | SAVE_BATCH_IMAGES_GT: true 118 | SAVE_BATCH_IMAGES_PRED: true 119 | SAVE_HEATMAPS_GT: true 120 | SAVE_HEATMAPS_PRED: true 121 | -------------------------------------------------------------------------------- /rule_based_programs/microprograms/temporal_segmentation_functions.py: -------------------------------------------------------------------------------- 1 | """ 2 | temporal_segmentation_functions.py 3 | Author: Lauren Okamoto 4 | 5 | Temporal Segmentation Microprograms for detecting start/takeoff, somersault, twist, and entry phases 6 | """ 7 | 8 | from rule_based_programs.microprograms.dive_error_functions import get_splash_from_one_frame, applyPositionTightnessError 9 | import numpy as np 10 | 11 | """ 12 | Start/Takeoff Phase Microprogram 13 | 14 | Parameters: 15 | - filepath (str): file path where the frame is located 16 | - above_board (bool): True if the diver is above the board at this frame 17 | - on_board (bool): True if the diver is on the 
board at this frame 18 | - pose_pred: pose estimation of the diver at this frame (None if no diver detected) 19 | 20 | Returns: 21 | - 0 if frame is not in start/takeoff phase 22 | - 1 if frame is in start/takeoff phase 23 | """ 24 | def takeoff_microprogram_one_frame(filepath, above_board, on_board, pose_pred=None): 25 | if not above_board: 26 | return 0 27 | if on_board: 28 | return 1 29 | return 0 30 | 31 | """ 32 | Somersault Phase Microprogram 33 | 34 | Parameters: 35 | - filepath (str): file path where the frame is located 36 | - on_board (bool): True if the diver is on the board at this frame 37 | - expected_som (int): number of somersaults in full dive (from action recognition) 38 | - half_som_count (int): number of somersaults counted by this frame 39 | - expected_twists (int): number of twists in full dive (from action recognition) 40 | - petal_count (int): number of twists counted by this frame 41 | - pose_pred: pose estimation of the diver at this frame (None if no diver detected) 42 | - diver_detector: diver detector model 43 | - pose_model: pose estimator model 44 | 45 | Returns: 46 | - 0 if frame is not in somersault phase 47 | - 1 if frame is in somersault phase 48 | """ 49 | def somersault_microprogram_one_frame(filepath, on_board, expected_som, half_som_count, expected_twists, petal_count, pose_pred=None, diver_detector=None, pose_model=None): 50 | if on_board: 51 | return 0 52 | if expected_som <= half_som_count: 53 | return 0 54 | # if not done with som or twists, need to determine if som or twist 55 | angle = applyPositionTightnessError(filepath, pose_pred, diver_detector=diver_detector, pose_model=pose_model) 56 | if angle is None: 57 | return 0 58 | # if not done with som but done with twists 59 | if expected_som > half_som_count and expected_twists <= petal_count: 60 | return 1 61 | # print("angle:", angle) 62 | if angle <= 80: 63 | return 1 64 | else: 65 | return 0 66 | 67 | """ 68 | Twist Phase Microprogram 69 | 70 | Parameters: 71 | - filepath (str): file path where the frame is located 72 | - on_board (bool): True if the diver is on the board at this frame 73 | - expected_twists (int): number of twists in full dive (from action recognition) 74 | - petal_count (int): number of twists counted by this frame 75 | - expected_som (int): number of somersaults in full dive (from action recognition) 76 | - half_som_count (int): number of somersaults counted by this frame 77 | - pose_pred: pose estimation of the diver at this frame (None if no diver detected) 78 | - diver_detector: diver detector model 79 | - pose_model: pose estimator model 80 | 81 | Returns: 82 | - 0 if frame is not in twist phase 83 | - 1 if frame is in twist phase 84 | """ 85 | def twist_microprogram_one_frame(filepath, on_board, expected_twists, petal_count, expected_som, half_som_count, pose_pred=None, diver_detector=None, pose_model=None): 86 | if on_board: 87 | return 0 88 | if expected_twists <= petal_count or expected_som <= half_som_count: 89 | return 0 90 | angle = applyPositionTightnessError(filepath, pose_pred=pose_pred, diver_detector=diver_detector, pose_model=pose_model) 91 | if angle is None: 92 | return 0 93 | if angle > 80: 94 | return 1 95 | else: 96 | return 0 97 | 98 | """ 99 | Entry Phase Microprogram 100 | 101 | Parameters: 102 | - filepath (str): file path where the frame is located 103 | - above_board (bool): True if the diver is above the board at this frame 104 | - on_board (bool): True if the diver is on the board at this frame 105 | - pose_pred: pose estimation of the diver 
at this frame (None if no diver detected)
106 | - expected_twists (int): number of twists in full dive (from action recognition)
107 | - petal_count (int): number of twists counted by this frame
108 | - expected_som (int): number of somersaults in full dive (from action recognition)
109 | - half_som_count (int): number of somersaults counted by this frame
110 | - frame: PIL Image of frame
111 | - splash_detector: splash detector model
112 | - visualize: True if you want to save the splash segmentation mask prediction to an image
113 | - dive_folder_num: if visualize is true, this is where the image will be saved
114 | 
115 | Returns:
116 | - 0 if frame is not in entry phase
117 | - 1 if frame is in entry phase
118 | """
119 | def entry_microprogram_one_frame(filepath, above_board, on_board, pose_pred, expected_twists, petal_count, expected_som, half_som_count, frame=None, splash_detector=None, visualize=False, dive_folder_num=None):
120 |     if above_board:
121 |         return 0
122 |     if on_board:
123 |         return 0
124 |     # get_splash_from_one_frame returns an (area, mask) pair, so unpack it and
125 |     # test the area itself rather than the (always truthy) tuple
126 |     splash_area, _ = get_splash_from_one_frame(filepath, im=frame, predictor=splash_detector, visualize=visualize, dive_folder_num=dive_folder_num)
127 |     if splash_area:
128 |         return 1
129 |     # if completed with somersaults, we know we're in entry phase
130 |     if not expected_som > half_som_count:
131 |         return 1
132 |     if expected_twists > petal_count or expected_som > half_som_count:
133 |         return 0
134 |     return 1
135 | 
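
Together, these four microprograms act as per-frame predicates, so temporal segmentation amounts to querying them in priority order once a frame's symbols (board flags, counters, pose) are known. The real orchestration, including how the counters are updated between frames, is in rule_based_programs/aqa_metaProgram.py; the driver below is only a sketch, and the `symbols` dict is a hypothetical container for the state the metaprogram actually tracks.

```python
# Hedged sketch: turn the four per-frame microprograms into a single phase
# label. `symbols` is a hypothetical dict holding the per-frame state that
# aqa_metaProgram.py maintains for real.
def label_phase(filepath, symbols):
    if takeoff_microprogram_one_frame(filepath, symbols['above_board'], symbols['on_board'],
                                      pose_pred=symbols['pose_pred']):
        return 'takeoff'
    if twist_microprogram_one_frame(filepath, symbols['on_board'], symbols['expected_twists'],
                                    symbols['petal_count'], symbols['expected_som'],
                                    symbols['half_som_count'], pose_pred=symbols['pose_pred']):
        return 'twist'
    if somersault_microprogram_one_frame(filepath, symbols['on_board'], symbols['expected_som'],
                                         symbols['half_som_count'], symbols['expected_twists'],
                                         symbols['petal_count'], pose_pred=symbols['pose_pred']):
        return 'somersault'
    if entry_microprogram_one_frame(filepath, symbols['above_board'], symbols['on_board'],
                                    symbols['pose_pred'], symbols['expected_twists'],
                                    symbols['petal_count'], symbols['expected_som'],
                                    symbols['half_som_count']):
        return 'entry'
    return 'unlabeled'
```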
-------------------------------------------------------------------------------- /models/pose_estimator/pose_estimator_model_setup.py: --------------------------------------------------------------------------------
1 | """
2 | pose_estimator_model_setup.py
3 | Author: Lauren Okamoto
4 | 
5 | Code used to initialize the pose estimator model, HRNet, combined with the diver detectron2 model to be used for diver pose inference.
6 | """
7 | 
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 | 
12 | import argparse
13 | import csv
14 | import os
15 | import shutil
16 | import sys
17 | 
18 | from PIL import Image
19 | import torch
20 | import torch.nn.parallel
21 | import torch.backends.cudnn as cudnn
22 | import torch.optim
23 | import torch.utils.data
24 | import torch.utils.data.distributed
25 | import torchvision.transforms as transforms
26 | import torchvision
27 | import cv2
28 | import numpy as np
29 | import time
30 | sys.path.append('./deep-high-resolution-net.pytorch/lib')
31 | import models
32 | from config import cfg
33 | from config import update_config
34 | from core.function import get_final_preds
35 | from utils.transforms import get_affine_transform
36 | 
37 | import distutils.core
38 | 
39 | from models.detectron2.detectors import get_diver_detector
40 | from models.pose_estimator.pose_hrnet import get_pose_net
41 | 
42 | 
43 | def box_to_center_scale(box, model_image_width, model_image_height):
44 |     """convert a box to center,scale information required for pose transformation
45 |     Parameters
46 |     ----------
47 |     box : list of tuple
48 |         list of length 2 with two tuples of floats representing
49 |         bottom left and top right corner of a box
50 |     model_image_width : int
51 |     model_image_height : int
52 | 
53 |     Returns
54 |     -------
55 |     (numpy array, numpy array)
56 |         Two numpy arrays, coordinates for the center of the box and the scale of the box
57 |     """
58 |     center = np.zeros((2), dtype=np.float32)
59 |     bottom_left_corner = (box[0].data.cpu().item(), box[1].data.cpu().item())
60 |     top_right_corner = (box[2].data.cpu().item(), box[3].data.cpu().item())
61 |     box_width = top_right_corner[0]-bottom_left_corner[0]
62 |     box_height = top_right_corner[1]-bottom_left_corner[1]
63 |     bottom_left_x = bottom_left_corner[0]
64 |     bottom_left_y = bottom_left_corner[1]
65 |     center[0] = bottom_left_x + box_width * 0.5
66 |     center[1] = bottom_left_y + box_height * 0.5
67 |     aspect_ratio = model_image_width * 1.0 / model_image_height
68 |     pixel_std = 200
69 |     if box_width > aspect_ratio * box_height:
70 |         box_height = box_width * 1.0 / aspect_ratio
71 |     elif box_width < aspect_ratio * box_height:
72 |         box_width = box_height * aspect_ratio
73 |     scale = np.array(
74 |         [box_width * 1.0 / pixel_std, box_height * 1.0 / pixel_std],
75 |         dtype=np.float32)
76 |     if center[0] != -1:
77 |         scale = scale * 1.25
78 |     return center, scale
79 | 
80 | def parse_args():
81 |     parser = argparse.ArgumentParser(description='Train keypoints network')
82 |     parser.add_argument('--cfg', type=str, default='./models/pose_estimator/cfg/w32_256x256_adam_lr1e-3.yaml')
83 |     parser.add_argument('opts',
84 |                         help='Modify config options using the command-line',
85 |                         default=None,
86 |                         nargs=argparse.REMAINDER)
87 |     args = parser.parse_args()
88 |     args.opts = ''
89 |     args.modelDir = ''
90 |     args.logDir = ''
91 |     args.dataDir = ''
92 |     args.prevModelDir = ''
93 |     return args
94 | 
95 | def get_pose_estimation_prediction(pose_model, image, center, scale):
96 |     rotation = 0
97 |     trans = get_affine_transform(center, scale, rotation, cfg.MODEL.IMAGE_SIZE)
98 |     transform = transforms.Compose([
99 |         transforms.ToTensor(),
100 |         transforms.Normalize(mean=[0.485, 0.456, 0.406],
101 |                              std=[0.229, 0.224, 0.225]),
102 |     ])
103 |     model_input = cv2.warpAffine(
104 |         image,
105 |         trans,
106 |         (256, 256),
107 |         flags=cv2.INTER_LINEAR)
108 | 
109 |     # pose estimation inference
110 |     model_input =
transform(model_input).unsqueeze(0) 111 | # switch to evaluate mode 112 | pose_model.eval() 113 | with torch.no_grad(): 114 | # compute output heatmap 115 | output = pose_model(model_input) 116 | preds, _ = get_final_preds( 117 | cfg, 118 | output.clone().cpu().numpy(), 119 | np.asarray([center]), 120 | np.asarray([scale])) 121 | return preds 122 | 123 | def get_pose_model(): 124 | CTX = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 125 | cudnn.benchmark = cfg.CUDNN.BENCHMARK 126 | torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC 127 | torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED 128 | args = parse_args() 129 | update_config(cfg, args) 130 | pose_model = get_pose_net(cfg, is_train=False) 131 | if cfg.TEST.MODEL_FILE: 132 | print('=> loading model from {}'.format(cfg.TEST.MODEL_FILE)) 133 | pose_model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=False) 134 | else: 135 | print('expected model defined in config at TEST.MODEL_FILE') 136 | pose_model = torch.nn.DataParallel(pose_model, device_ids=cfg.GPUS) 137 | pose_model.to(CTX) 138 | pose_model.eval() 139 | return pose_model 140 | 141 | def get_pose_estimation(filepath, image_bgr=None, diver_detector=None, pose_model=None): 142 | if image_bgr is None: 143 | image_bgr = cv2.imread(filepath) 144 | if image_bgr is None: 145 | print("ERROR: image {} does not exist".format(filepath)) 146 | return None 147 | if diver_detector is None: 148 | diver_detector = get_diver_detector() 149 | 150 | if pose_model is None: 151 | pose_model = get_pose_model() 152 | 153 | image = image_bgr[:, :, [2, 1, 0]] 154 | 155 | outputs = diver_detector(image_bgr) 156 | scores = outputs['instances'].scores 157 | pred_boxes = [] 158 | if len(scores) > 0: 159 | pred_boxes = outputs['instances'].pred_boxes 160 | 161 | if len(pred_boxes) >= 1: 162 | for box in pred_boxes: 163 | center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1]) 164 | image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy() 165 | box = box.detach().cpu().numpy() 166 | return box, get_pose_estimation_prediction(pose_model, image_pose, center, scale) 167 | return None, None 168 | -------------------------------------------------------------------------------- /score_report_generation/templates/report_template_tables.html: -------------------------------------------------------------------------------- 1 | 2 | 3 |
<!DOCTYPE html>
<html>
<body>
<p>Your dive was {{overall_score_desc}}, and scored a {{overall_score}}. Here is how we scored each component of the dive. Each percentile is relative to dives in the semifinals and finals of Olympic and World Championship competitions.</p>
<table>
  <tr>
    <th>Error</th>
    <th>Description</th>
    <th>Visuals</th>
    <th>Score</th>
  </tr>
  <tr>
    <td>Feet Apart</td>
    <td>We found that your leg separation angle was on average {{feet_apart_score}}° for your dive. This is rated as {{feet_apart_percentile}} percentile.</td>
    <td>
      {% if has_feet_apart_peaks %}
      <img alt="Feet Apart GIF">
      {% else %}
      There were no particular instances to show where your feet came apart.
      {% endif %}
    </td>
    <td>{{feet_apart_percentile_divided_by_ten}}</td>
  </tr>
  {% if include_height_off_platform %}
  <tr>
    <td>Height off platform</td>
    <td>Your jump was {{height_off_board_description}}, and was rated as {{height_off_board_percentile}} percentile.</td>
    <td>Here is the highest you jumped off the platform.</td>
    <td>{{height_off_board_percentile_divided_by_ten}}</td>
  </tr>
  {% endif %}
  <tr>
    <td>Distance from platform</td>
    <td>You were {{dist_from_board_percentile}} the platform.</td>
    <td>Here is where you came closest to the platform.</td>
    <td>{{dist_from_board_percentile_status}}</td>
  </tr>
  {% if som_position_tightness_frames|length > 0 %}
  <tr>
    <td>Somersault tightness</td>
    <td>We found that the tightness of your {{som_position_tightness_position}} was {{som_position_tightness_score}}° on average. This is rated as {{som_position_tightness_percentile}} percentile. Here are some examples of your position in the somersault.</td>
    <td>
      {% if knee_bend_frames|length > 0 %}
      <img alt="Somersault GIF">
      {% else %}
      <img alt="Somersault GIF">
      {% endif %}
    </td>
    <td>{{som_position_tightness_percentile_divided_by_ten}}</td>
  </tr>
  {% endif %}
  {% if knee_bend_frames|length > 0 %}
  <tr>
    <td>Knee straightness</td>
    <td>We found that your knees bent {{knee_bend_score}}° on average. This is rated as {{knee_bend_percentile}} percentile.</td>
    <td></td>
    <td>{{knee_bend_percentile_divided_by_ten}}</td>
  </tr>
  {% endif %}
  {% if is_twister %}
  <tr>
    <td>Twist Straightness</td>
    <td>We found that the tightness of your {{twist_position_tightness_position}} was {{twist_position_tightness_score}}° on average. This is rated as {{twist_position_tightness_percentile}} percentile. Here are some examples of your position in the twist.</td>
    <td><img alt="Twist GIF"></td>
    <td>{{twist_position_tightness_percentile_divided_by_ten}}</td>
  </tr>
  {% endif %}
  <tr>
    <td>Verticalness (over/under rotation)</td>
    <td>We found that you deviated from vertical by {{over_under_rotation_score}}°, which was the {{over_under_rotation_percentile}} percentile.</td>
    <td><img alt="Entry GIF"></td>
    <td>{{over_under_rotation_percentile_divided_by_ten}}</td>
  </tr>
  <tr>
    <td>Body straightness during entry</td>
    <td>The straightness of your body during entry deviated by {{straightness_during_entry_score}}°, which was the {{straightness_during_entry_percentile}} percentile.</td>
    <td></td>
    <td>{{straightness_during_entry_percentile_divided_by_ten}}</td>
  </tr>
  <tr>
    <td>Splash</td>
    <td>Your splash was {{splash_description}} and rated as {{splash_percentile}} percentile.</td>
    <td><img alt="Splash GIF"></td>
    <td>{{splash_percentile_divided_by_ten}}</td>
  </tr>
</table>
</body>
</html>
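
generate_report_functions.py fills this template with the score dictionary produced by get_all_report_scores. The {{ ... }} and {% ... %} markup is Jinja-style templating; the snippet below is a hedged sketch of rendering such a template standalone, and the loader path and minimal context it uses are illustrative assumptions rather than the module's actual API.

```python
# Hedged sketch: render the score-report template directly with Jinja2.
# The shipped pipeline does this inside generate_report_functions.py.
import os
from jinja2 import Environment, FileSystemLoader

env = Environment(loader=FileSystemLoader('./score_report_generation/templates'))
template = env.get_template('report_template_tables.html')
# Minimal illustrative context: the flags and lists drive the {% if %} blocks,
# and any display variable left out simply renders as blank.
context = {
    'overall_score_desc': 'excellent', 'overall_score': 8.5,
    'has_feet_apart_peaks': False, 'include_height_off_platform': True,
    'som_position_tightness_frames': [], 'knee_bend_frames': [],
    'is_twister': False,
}
html = template.render(**context)
os.makedirs('./output', exist_ok=True)
with open('./output/report.html', 'w') as f:
    f.write(html)
```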
147 | 148 | 149 | 150 | -------------------------------------------------------------------------------- /rule_based_programs/microprograms/dive_error_functions.py: -------------------------------------------------------------------------------- 1 | """ 2 | dive_error_functions.py 3 | Author: Lauren Okamoto 4 | 5 | AQA functions. 6 | """ 7 | 8 | import torch 9 | import numpy as np 10 | import math 11 | import cv2 12 | import sys, os 13 | from matplotlib import image 14 | from matplotlib import pyplot as plt 15 | from math import atan 16 | from models.pose_estimator.pose_estimator_model_setup import get_pose_estimation 17 | from models.detectron2.detectors import get_platform_detector, get_splash_detector 18 | from detectron2.utils.visualizer import Visualizer 19 | 20 | ################## HELPER FUNCTIONS ################## 21 | def slope(x1, y1, x2, y2): 22 | if x1 == x2: 23 | return "undefined" 24 | return (y2-y1)/(x2-x1) 25 | 26 | # Function to find the angle between two lines 27 | def findAngle(M1, M2): 28 | vertical_line = False 29 | if M1 == "undefined": 30 | M1 = 0 31 | vertical_line = True 32 | if M2 == "undefined": 33 | M2 = 0 34 | vertical_line = True 35 | PI = 3.14159265 36 | angle = abs((M2 - M1) / (1 + M1 * M2)) 37 | ret = atan(angle) 38 | val = (ret * 180) / PI 39 | if vertical_line: 40 | return 90 - round(val,4) 41 | return (round(val, 4)) 42 | 43 | def find_which_side_board_on(output): 44 | pred_classes = output['instances'].pred_classes.cpu().numpy() 45 | platforms = np.where(pred_classes == 0)[0] 46 | scores = output['instances'].scores[platforms] 47 | if len(scores) == 0: 48 | return 49 | pred_masks = output['instances'].pred_masks[platforms] 50 | max_instance = torch.argmax(scores) 51 | pred_mask = np.array(pred_masks[max_instance].cpu()) 52 | for i in range(len(pred_mask[0])//2): 53 | if sum(pred_mask[:, i]) != 0: 54 | return "left" 55 | elif sum(pred_mask[:, len(pred_mask[0]) - i - 1]) != 0: 56 | return "right" 57 | return None 58 | 59 | def board_end(output, board_side=None): 60 | pred_classes = output['instances'].pred_classes.cpu().numpy() 61 | platforms = np.where(pred_classes == 0)[0] 62 | scores = output['instances'].scores[platforms] 63 | if len(scores) == 0: 64 | return 65 | pred_masks = output['instances'].pred_masks[platforms] 66 | max_instance = torch.argmax(scores) 67 | pred_mask = np.array(pred_masks[max_instance].cpu()) # splash instance with highest confidence 68 | # need to figure out whether springboard is on left or right side of frame, then need to find where the edge is 69 | if board_side is None: 70 | board_side = find_which_side_board_on(output) 71 | if board_side == "left": 72 | for i in range(len(pred_mask[0]) - 1, -1, -1): 73 | if sum(pred_mask[:, i]) != 0: 74 | trues = np.where(pred_mask[:, i])[0] 75 | return (i, min(trues)) 76 | if board_side == "right": 77 | for i in range(len(pred_mask[0])): 78 | if sum(pred_mask[:, i]) != 0: 79 | trues = np.where(pred_mask[:, i])[0] 80 | return (i, min(trues)) 81 | return None 82 | 83 | # Splash helper function 84 | def get_splash_pred_mask(output): 85 | pred_classes = output['instances'].pred_classes.cpu().numpy() 86 | splashes = np.where(pred_classes == 0)[0] 87 | scores = output['instances'].scores[splashes] 88 | if len(scores) == 0: 89 | return None 90 | pred_masks = output['instances'].pred_masks[splashes] 91 | max_instance = torch.argmax(scores) 92 | pred_mask = np.array(pred_masks[max_instance].cpu()) 93 | return pred_mask 94 | 95 | # Splash helper function that finds the splash instance with the 
highest percent confidence
96 | # and returns the fraction of the frame area that the splash covers
97 | def splash_area_percentage(output, pred_mask=None):
98 |     if pred_mask is None:
99 |         return
100 |     # loop over pixel rows to get the total number of splash pixels
101 |     totalSum = 0
102 |     for j in range(len(pred_mask)):
103 |         totalSum += pred_mask[j].sum()
104 |     # return fraction of the image that is splash
105 |     return totalSum/(len(pred_mask) * len(pred_mask[0]))
106 | 
107 | def draw_two_coord(im, coord1, coord2, filename):
108 |     # mark the two coordinates on the image (red and green)
109 |     image = cv2.circle(im, (int(coord1[0]),int(coord1[1])), radius=5, color=(0, 0, 255), thickness=-1)
110 |     image = cv2.circle(image, (int(coord2[0]),int(coord2[1])), radius=5, color=(0, 255, 0), thickness=-1)
111 |     # write the visualization to disk
112 |     if not cv2.imwrite(filename, image):
113 |         print(filename)
114 |         print("file failed to write")
115 | 
116 | def draw_board_end_coord(im, coord, file_name):
117 |     # mark the detected platform end on the image
118 |     image = cv2.circle(im, (int(coord[0]),int(coord[1])), radius=10, color=(0, 0, 255), thickness=-1)
119 |     filename = os.path.join("./output/board_end/", file_name)
120 |     # write the visualization to disk
121 |     if not cv2.imwrite(filename, image):
122 |         print(filename)
123 |         print("file failed to write")
124 | 
125 | 
126 | 
127 | #######################################################
128 | 
129 | ################## DIVE ERROR MICROPROGRAMS ##################
130 | 
131 | ## Feet apart error ##
132 | def applyFeetApartError(filepath, pose_pred=None, diver_detector=None, pose_model=None):
133 |     if pose_pred is None and filepath != "":
134 |         diver_box, pose_pred = get_pose_estimation(filepath, diver_detector=diver_detector, pose_model=pose_model)
135 |     if pose_pred is not None:
136 |         pose_pred = np.array(pose_pred)[0]
137 |         average_knee = [np.mean((pose_pred[4][0], pose_pred[1][0])), np.mean((pose_pred[4][1], pose_pred[1][1]))]
138 |         vector1 = [pose_pred[5][0] - average_knee[0], pose_pred[5][1] - average_knee[1]]
139 |         vector2 = [pose_pred[0][0] - average_knee[0], pose_pred[0][1] - average_knee[1]]
140 |         unit_vector_1 = vector1 / np.linalg.norm(vector1)
141 |         unit_vector_2 = vector2 / np.linalg.norm(vector2)
142 |         dot_product = np.dot(unit_vector_1, unit_vector_2)
143 |         angle = math.degrees(np.arccos(dot_product))
144 |         return angle
145 |     else:
146 |         return None
147 | 
148 | ## Calculates hip bend for somersault tightness & twist straightness errors ##
149 | def applyPositionTightnessError(filepath, pose_pred=None, diver_detector=None, pose_model=None):
150 |     if pose_pred is None and filepath != "":
151 |         diver_box, pose_pred = get_pose_estimation(filepath, diver_detector=diver_detector, pose_model=pose_model)
152 |     if pose_pred is not None:
153 |         pose_pred = np.array(pose_pred)[0]
154 |         vector1 = [pose_pred[7][0] - pose_pred[2][0], pose_pred[7][1] - pose_pred[2][1]]
155 |         vector2 = [pose_pred[1][0] - pose_pred[2][0], pose_pred[1][1] - pose_pred[2][1]]
156 |         unit_vector_1 = vector1 / np.linalg.norm(vector1)
157 |         unit_vector_2 = vector2 / np.linalg.norm(vector2)
158 |         dot_product = np.dot(unit_vector_1, unit_vector_2)
159 |         angle = math.degrees(np.arccos(dot_product))
160 |         return angle
161 |     else:
162 |         return None
163 | 
164 | ## Distance from board error ##
165 | def calculate_distance_from_platform_for_one_frame(filepath, im=None, visualize=False, dive_folder_num="", platform_detector=None, pose_pred=None, diver_detector=None, pose_model=None, board_end_coord=None, board_side=None):
166 |     if platform_detector is None:
167 |         platform_detector = get_platform_detector()
168 |     if
pose_pred is None: 169 | diver_box, pose_pred = get_pose_estimation(filepath, image_bgr=im, diver_detector=diver_detector, pose_model=pose_model) 170 | if im is None and filepath != "": 171 | im = cv2.imread(filepath) 172 | if board_end_coord is None: 173 | outputs = platform_detector(im) 174 | board_end_coord = board_end(outputs, board_side=board_side) 175 | minDist = None 176 | if board_end_coord is not None and pose_pred is not None and len(board_end_coord) == 2: 177 | minDist = float('inf') 178 | for i in range(len(np.array(pose_pred)[0])): 179 | distance = math.dist(np.array(pose_pred)[0][i], np.array(board_end_coord)) 180 | if distance < minDist: 181 | minDist = distance 182 | minJoint = i 183 | if visualize: 184 | file_name = filepath.split('/')[-1] 185 | folder = "./output/data/distance_from_board/{}".format(dive_folder_num) 186 | out_filename = os.path.join(folder, file_name) 187 | if not os.path.exists(folder): 188 | os.makedirs(folder) 189 | draw_two_coord(im, board_end_coord, np.array(pose_pred)[0][minJoint], filename=out_filename) 190 | return minDist 191 | 192 | ## Over-rotation error ## 193 | def over_rotation(filepath, pose_pred=None, diver_detector=None, pose_model=None): 194 | if pose_pred is None and filepath != "": 195 | diver_box, pose_pred = get_pose_estimation(filepath, diver_detector=diver_detector, pose_model=pose_model) 196 | if pose_pred is not None: 197 | pose_pred = np.array(pose_pred)[0] 198 | vector1 = [(pose_pred[0][0] - pose_pred[2][0]), 0-(pose_pred[0][1] - pose_pred[2][1])] 199 | vector2 = [-1, 0] 200 | unit_vector_1 = vector1 / np.linalg.norm(vector1) 201 | unit_vector_2 = vector2 / np.linalg.norm(vector2) 202 | dot_product = np.dot(unit_vector_1, unit_vector_2) 203 | angle = math.degrees(np.arccos(dot_product)) 204 | return angle 205 | else: 206 | return None 207 | 208 | ## Splash size error ## 209 | def get_splash_from_one_frame(filepath, im=None, predictor=None, visualize=False, dive_folder_num=""): 210 | if predictor is None: 211 | predictor=get_splash_detector() 212 | if im is None: 213 | im = cv2.imread(filepath) 214 | outputs = predictor(im) 215 | pred_mask = get_splash_pred_mask(outputs) 216 | area = splash_area_percentage(outputs, pred_mask=pred_mask) 217 | if area is None: 218 | return None, None 219 | if visualize: 220 | pred_boxes = outputs['instances'].pred_boxes 221 | print("pred_boxes", pred_boxes) 222 | for box in pred_boxes: 223 | image = cv2.rectangle(im, (int(box[0]),int(box[1])), (int(box[2]),int(box[3])), color=(0, 0, 255), thickness=2) 224 | out_folder= "./output/data/splash/{}".format(dive_folder_num) 225 | if not os.path.exists(out_folder): 226 | os.makedirs(out_folder) 227 | filename = os.path.join(out_folder, filepath.split('/')[-1]) 228 | if not cv2.imwrite(filename, image): 229 | print('no image written to', filename) 230 | break 231 | 232 | return area.tolist(), pred_mask 233 | -------------------------------------------------------------------------------- /rule_based_programs/aqa_metaProgram_finediving.py: -------------------------------------------------------------------------------- 1 | """ 2 | aqa_metaProgram_finediving.py 3 | Author: Lauren Okamoto 4 | """ 5 | 6 | from rule_based_programs.microprograms.dive_error_functions import * 7 | from rule_based_programs.microprograms.temporal_segmentation_functions import * 8 | from rule_based_programs.microprograms.dive_recognition_functions import * 9 | from rule_based_programs.scoring_functions import get_scale_factor 10 | from models.detectron2.detectors import 
get_platform_detector, get_diver_detector, get_splash_detector 11 | from models.pose_estimator.pose_estimator_model_setup import get_pose_estimation, get_pose_model 12 | import pickle 13 | import os, math 14 | import numpy as np 15 | import cv2 16 | import argparse 17 | 18 | def getDiveInfo_from_diveNum(diveNum): 19 | handstand = (diveNum[0] == '6') 20 | expected_som = int(diveNum[2]) 21 | if len(diveNum) == 5: 22 | expected_twists = int(diveNum[3]) 23 | else: 24 | expected_twists = 0 25 | if diveNum[0] == '1' or diveNum[0] == '3' or diveNum[:2] == '51' or diveNum[:2] == '53' or diveNum[:2] == '61' or diveNum[:2] == '63': 26 | back_facing = False 27 | else: 28 | back_facing = True 29 | if diveNum[0] == '1' or diveNum[:2] == '51' or diveNum[:2] == '61': 30 | expected_direction = 'front' 31 | elif diveNum[0] == '2' or diveNum[:2] == '52' or diveNum[:2] == '62': 32 | expected_direction = 'back' 33 | elif diveNum[0] == '3' or diveNum[:2] == '53' or diveNum[:2] == '63': 34 | expected_direction = 'reverse' 35 | elif diveNum[0] == '4': 36 | expected_direction = 'inward' 37 | if diveNum[-1] == 'b': 38 | position = 'pike' 39 | elif diveNum[-1] == 'c': 40 | position = 'tuck' 41 | else: 42 | position = 'free' 43 | return handstand, expected_som, expected_twists, back_facing, expected_direction, position 44 | 45 | def aqa_metaprogram_finediving(first_folder, second_folder, diveNum, board_side=None, platform_detector=None, splash_detector=None, diver_detector=None, pose_model=None): 46 | handstand, expected_som, expected_twists, back_facing, expected_direction, position = getDiveInfo_from_diveNum(diveNum) 47 | dive_data = {} 48 | takeoff = [] 49 | twist = [] 50 | som = [] 51 | entry = [] 52 | distance_from_board = [] 53 | position_tightness = [] 54 | feet_apart = [] 55 | over_under_rotation = [] 56 | splash = [] 57 | pose_preds = [] 58 | diver_boxes = [] 59 | above_boards = [] 60 | on_boards = [] 61 | som_counts = [] 62 | twist_counts = [] 63 | board_end_coords = [] 64 | plat_outputs = [] 65 | board_sides = [] 66 | splash_pred_masks = [] 67 | above_board = True 68 | on_board = True 69 | if platform_detector is None: 70 | platform_detector = get_platform_detector() 71 | if splash_detector is None: 72 | splash_detector = get_splash_detector() 73 | if diver_detector is None: 74 | diver_detector = get_diver_detector() 75 | if pose_model is None: 76 | pose_model = get_pose_model() 77 | key = (first_folder, int(second_folder)) 78 | dive_folder_num = "{}_{}".format(first_folder, second_folder) 79 | directory = './FineDiving/datasets/FINADiving_MTL_256s/{}/{}/'.format(first_folder, second_folder) 80 | file_names = os.listdir(directory) 81 | 82 | ## find board_side 83 | if board_side is None: 84 | for i in range(len(file_names)): 85 | filepath = directory + file_names[i] 86 | if file_names[i][-4:] != ".jpg": 87 | continue 88 | im = cv2.imread(filepath) 89 | plat_output = platform_detector(im) 90 | board_side = find_which_side_board_on(plat_output) 91 | if board_side is not None: 92 | board_sides.append(board_side) 93 | dive_data['board_sides'] = board_sides 94 | board_sides.sort() 95 | board_side = board_sides[len(board_sides)//2] 96 | dive_data['board_side'] = board_side 97 | 98 | prev_pred = None 99 | som_prev_pred = None 100 | half_som_count=0 101 | petal_count = 0 102 | in_petal = False 103 | for i in range(len(file_names)): 104 | filepath = directory + file_names[i] 105 | if file_names[i][-4:] != ".jpg": 106 | continue 107 | diver_box, pose_pred = get_pose_estimation(filepath, 
diver_detector=diver_detector, pose_model=pose_model) 108 | diver_boxes.append(diver_box) 109 | pose_preds.append(pose_pred) 110 | 111 | calculated_half_som_count, skip = som_counter(pose_pred, prev_pose_pred=som_prev_pred, half_som_count=half_som_count, handstand=handstand) 112 | if not skip: 113 | som_prev_pred = pose_pred 114 | calculated_petal_count, calculated_in_petal = twist_counter(pose_pred, prev_pose_pred=prev_pred, in_petal=in_petal, petal_count=petal_count) 115 | im = cv2.imread(filepath) 116 | plat_output = platform_detector(im) 117 | plat_outputs.append(plat_output) 118 | board_end_coord = board_end(plat_output, board_side=board_side) 119 | board_end_coords.append(board_end_coord) 120 | if above_board and not on_board and board_end_coord is not None and pose_pred is not None and np.array(pose_pred)[0][2][1] > int(board_end_coord[1]): 121 | above_board=False 122 | if on_board and detect_on_board(board_end_coord, board_side, pose_pred, handstand) is not None and not detect_on_board(board_end_coord, board_side, pose_pred, handstand): 123 | on_board = False 124 | if above_board: 125 | above_boards.append(1) 126 | else: 127 | above_boards.append(0) 128 | if on_board: 129 | on_boards.append(1) 130 | else: 131 | on_boards.append(0) 132 | calculated_takeoff = takeoff_microprogram_one_frame(filepath, above_board=above_board, on_board=on_board, pose_pred=pose_pred) 133 | calculated_twist = twist_microprogram_one_frame(filepath, on_board=on_board, pose_pred=pose_pred, expected_twists=expected_twists, petal_count=petal_count, expected_som=expected_som, half_som_count=half_som_count, diver_detector=diver_detector, pose_model=pose_model) 134 | calculated_som = somersault_microprogram_one_frame(filepath, pose_pred=pose_pred, on_board=on_board, expected_som=expected_som, half_som_count=half_som_count, expected_twists=expected_twists, petal_count=petal_count, diver_detector=diver_detector, pose_model=pose_model) 135 | calculated_entry = entry_microprogram_one_frame(filepath, above_board=above_board, on_board=on_board, pose_pred=pose_pred, expected_twists=expected_twists, petal_count=petal_count, expected_som=expected_som, half_som_count=half_som_count, splash_detector=splash_detector, visualize=False, dive_folder_num=dive_folder_num) 136 | if calculated_som == 1: 137 | half_som_count = calculated_half_som_count 138 | elif calculated_twist == 1: 139 | half_som_count = calculated_half_som_count 140 | petal_count = calculated_petal_count 141 | in_petal = calculated_in_petal 142 | dist = calculate_distance_from_platform_for_one_frame(filepath, visualize=False, pose_pred=pose_pred, diver_detector=diver_detector, pose_model=pose_model, board_end_coord=board_end_coord, platform_detector=platform_detector) # saves photo to ./output/data/distance_from_board/ 143 | distance_from_board.append(dist) 144 | position_tightness.append(applyPositionTightnessError(filepath, pose_pred=pose_pred, diver_detector=diver_detector, pose_model=pose_model)) 145 | feet_apart.append(applyFeetApartError(filepath, pose_pred=pose_pred, diver_detector=diver_detector, pose_model=pose_model)) 146 | over_under_rotation.append(over_rotation(filepath, pose_pred=pose_pred, diver_detector=diver_detector, pose_model=pose_model)) 147 | splash_area, splash_pred_mask = get_splash_from_one_frame(filepath, predictor=splash_detector, visualize=False) 148 | splash.append(splash_area) 149 | splash_pred_masks.append(splash_pred_mask) 150 | takeoff.append(calculated_takeoff) 151 | twist.append(calculated_twist) 152 | 
som.append(calculated_som) 153 | entry.append(calculated_entry) 154 | som_counts.append(half_som_count) 155 | twist_counts.append(petal_count) 156 | prev_pred = pose_pred 157 | 158 | dive_data['pose_pred'] = pose_preds 159 | dive_data['takeoff'] = takeoff 160 | dive_data['twist'] = twist 161 | dive_data['som'] = som 162 | dive_data['entry'] = entry 163 | dive_data['distance_from_board'] = distance_from_board 164 | dive_data['position_tightness'] = position_tightness 165 | dive_data['feet_apart'] = feet_apart 166 | dive_data['over_under_rotation'] = over_under_rotation 167 | dive_data['splash'] = splash 168 | dive_data['above_boards'] = above_boards 169 | dive_data['on_boards'] = on_boards 170 | dive_data['som_counts'] = som_counts 171 | dive_data['twist_counts'] = twist_counts 172 | dive_data['board_end_coords'] = board_end_coords 173 | dive_data['diver_boxes'] = diver_boxes 174 | dive_data['splash_pred_masks'] = splash_pred_masks 175 | dive_data['board_side'] = board_side 176 | dive_data['is_handstand'] = handstand 177 | dive_data['direction'] = expected_direction 178 | dive_data['plat_outputs'] = plat_outputs 179 | return dive_data 180 | 181 | if __name__ == '__main__': 182 | # Set up command-line arguments 183 | new_parser = argparse.ArgumentParser(description="Extract dive data to be used for scoring.") 184 | new_parser.add_argument("FineDiving_key", type=str, nargs=2, help="key from FineDiving Dataset (e.g. 01 1)") 185 | meta_program_args = new_parser.parse_args() 186 | 187 | # Fine-grained annotations from FineDiving Dataset 188 | with open('FineDiving/Annotations/fine-grained_annotation_aqa.pkl', 'rb') as f: 189 | dive_annotation_data = pickle.load(f) 190 | 191 | key = tuple(meta_program_args.FineDiving_key) 192 | key = (key[0], int(key[1])) 193 | print(key) 194 | platform_detector = get_platform_detector() 195 | splash_detector = get_splash_detector() 196 | diver_detector = get_diver_detector() 197 | pose_model = get_pose_model() 198 | diveNum = dive_annotation_data[key][0] 199 | print(diveNum) 200 | dive_data = aqa_metaprogram_finediving(key[0], key[1], diveNum, platform_detector=platform_detector, splash_detector=splash_detector, diver_detector=diver_detector, pose_model=pose_model) 201 | 202 | save_path = "./output/{}_{}.pkl".format(key[0], key[1]) 203 | with open(save_path, 'wb') as f: 204 | print("saving data into " + save_path) 205 | pickle.dump(dive_data, f) 206 | -------------------------------------------------------------------------------- /rule_based_programs/aqa_metaProgram.py: -------------------------------------------------------------------------------- 1 | """ 2 | aqa_metaProgram.py 3 | Author: Lauren Okamoto 4 | """ 5 | 6 | from rule_based_programs.microprograms.dive_error_functions import * 7 | from rule_based_programs.microprograms.temporal_segmentation_functions import * 8 | from rule_based_programs.microprograms.dive_recognition_functions import * 9 | from rule_based_programs.scoring_functions import get_scale_factor 10 | from models.detectron2.detectors import get_platform_detector, get_diver_detector, get_splash_detector 11 | from models.pose_estimator.pose_estimator_model_setup import get_pose_estimation, get_pose_model 12 | import gradio as gr 13 | import pickle 14 | import os, sys, math 15 | import numpy as np 16 | import cv2 17 | import argparse 18 | 19 | def extract_frames(video_path): 20 | cap = cv2.VideoCapture(video_path) 21 | # Check if the video file is opened successfully 22 | if not cap.isOpened(): 23 | print("Error: Couldn't open video file.") 
24 | exit() 25 | frame_skip = 1 26 | # a variable to keep track of the frame to be saved 27 | frame_count = 0 28 | frames = [] 29 | i = 0 30 | while True: 31 | ret, frame = cap.read() 32 | if not ret: 33 | break 34 | if i > frame_skip - 1: 35 | frame_count += 1 36 | frame = cv2.resize(frame, (455, 256)) # resize takes argument (width, height) 37 | frames.append(frame) 38 | i = 0 39 | continue 40 | i += 1 41 | cap.release() 42 | return frames 43 | 44 | def getDiveInfo_from_diveNum(diveNum): 45 | handstand = (diveNum[0] == '6') 46 | expected_som = int(diveNum[2]) 47 | if len(diveNum) == 5: 48 | expected_twists = int(diveNum[3]) 49 | else: 50 | expected_twists = 0 51 | if diveNum[0] == '1' or diveNum[0] == '3' or diveNum[:2] == '51' or diveNum[:2] == '53' or diveNum[:2] == '61' or diveNum[:2] == '63': 52 | back_facing = False 53 | else: 54 | back_facing = True 55 | if diveNum[0] == '1' or diveNum[:2] == '51' or diveNum[:2] == '61': 56 | expected_direction = 'front' 57 | elif diveNum[0] == '2' or diveNum[:2] == '52' or diveNum[:2] == '62': 58 | expected_direction = 'back' 59 | elif diveNum[0] == '3' or diveNum[:2] == '53' or diveNum[:2] == '63': 60 | expected_direction = 'reverse' 61 | elif diveNum[0] == '4': 62 | expected_direction = 'inward' 63 | if diveNum[-1] == 'b': 64 | position = 'pike' 65 | elif diveNum[-1] == 'c': 66 | position = 'tuck' 67 | else: 68 | position = 'free' 69 | return handstand, expected_som, expected_twists, back_facing, expected_direction, position 70 | 71 | def getDiveInfo_from_symbols(frames, dive_data=None, platform_detector=None, splash_detector=None, diver_detector=None, pose_model=None): 72 | print("Getting dive info from symbols...") 73 | if dive_data is None: 74 | print("somethings not getting passed in properly") 75 | dive_data = abstractSymbols(frames, platform_detector=platform_detector, splash_detector=splash_detector, diver_detector=diver_detector, pose_model=pose_model) 76 | 77 | # get above_boards, on_boards, and position_tightness 78 | above_board = True 79 | on_board = True 80 | above_boards = [] 81 | on_boards = [] 82 | position_tightness = [] 83 | distances = [] 84 | prev_board_coord = None 85 | for i in range(len(dive_data['pose_pred'])): 86 | pose_pred = dive_data['pose_pred'][i] 87 | board_end_coord = dive_data['board_end_coords'][i] 88 | if board_end_coord is not None and prev_board_coord is not None: 89 | distances.append(math.dist(board_end_coord, prev_board_coord)) 90 | if math.dist(board_end_coord, prev_board_coord) > 150: 91 | position_tightness.append(applyPositionTightnessError(filepath="", pose_pred=pose_pred, diver_detector=diver_detector, pose_model=pose_model)) 92 | if above_board: 93 | above_boards.append(1) 94 | else: 95 | above_boards.append(0) 96 | if on_board: 97 | on_boards.append(1) 98 | else: 99 | on_boards.append(0) 100 | continue 101 | if above_board and not on_board and board_end_coord is not None and pose_pred is not None and np.array(pose_pred)[0][2][1] > int(board_end_coord[1]): 102 | above_board=False 103 | if on_board: 104 | handstand = is_handstand(dive_data) 105 | calculate_on_board = detect_on_board(board_end_coord, dive_data['board_side'], pose_pred, handstand) 106 | if calculate_on_board is not None and not calculate_on_board: 107 | on_board = False 108 | if above_board: 109 | above_boards.append(1) 110 | else: 111 | above_boards.append(0) 112 | if on_board: 113 | on_boards.append(1) 114 | else: 115 | on_boards.append(0) 116 | prev_board_coord = board_end_coord 117 | 
position_tightness.append(applyPositionTightnessError(filepath="", pose_pred=pose_pred, diver_detector=diver_detector, pose_model=pose_model)) 118 | dive_data['on_boards'] = on_boards 119 | dive_data['above_boards'] = above_boards 120 | dive_data['position_tightness'] = position_tightness 121 | 122 | ## handstand and som_count## 123 | expected_som, handstand = som_counter_full_dive(dive_data) 124 | 125 | ## twist_count 126 | expected_twists = twist_counter_full_dive(dive_data) 127 | 128 | ## direction: front, back, reverse, inward 129 | expected_direction = get_direction(dive_data) 130 | 131 | return handstand, expected_som, expected_twists, expected_direction, dive_data 132 | 133 | 134 | def abstractSymbols(frames, progress=gr.Progress(), platform_detector=None, splash_detector=None, diver_detector=None, pose_model=None): 135 | print("Abstracting symbols...") 136 | splashes = [] 137 | pose_preds = [] 138 | board_sides = [] 139 | plat_outputs = [] 140 | diver_boxes = [] 141 | splash_pred_masks = [] 142 | if platform_detector is None: 143 | platform_detector = get_platform_detector() 144 | if splash_detector is None: 145 | splash_detector = get_splash_detector() 146 | if diver_detector is None: 147 | diver_detector = get_diver_detector() 148 | if pose_model is None: 149 | pose_model = get_pose_model() 150 | num_frames = len(frames) 151 | i = 0 152 | for frame in frames: 153 | progress(i/num_frames, desc="Abstracting Symbols") 154 | plat_output = platform_detector(frame) 155 | plat_outputs.append(plat_output) 156 | board_side = find_which_side_board_on(plat_output) 157 | if board_side is not None: 158 | board_sides.append(board_side) 159 | diver_box, pose_pred = get_pose_estimation(filepath="", image_bgr=frame, diver_detector=diver_detector, pose_model=pose_model) 160 | pose_preds.append(pose_pred) 161 | diver_boxes.append(diver_box) 162 | splash_area, splash_pred_mask = get_splash_from_one_frame(filepath="", im=frame, predictor=splash_detector, visualize=False) 163 | splash_pred_masks.append(splash_pred_mask) 164 | splashes.append(splash_area) 165 | i+=1 166 | dive_data = {} 167 | dive_data['plat_outputs'] = plat_outputs 168 | dive_data['pose_pred'] = pose_preds 169 | dive_data['splash'] = splashes 170 | dive_data['splash_pred_masks'] = splash_pred_masks 171 | dive_data['board_sides'] = board_sides 172 | board_sides.sort() 173 | board_side = board_sides[len(board_sides)//2] 174 | dive_data['board_side'] = board_side 175 | dive_data['diver_boxes'] = diver_boxes 176 | 177 | # get board_end_coords 178 | board_end_coords = [] 179 | for plat_output in dive_data['plat_outputs']: 180 | board_end_coord = board_end(plat_output, board_side=dive_data['board_side']) 181 | board_end_coords.append(board_end_coord) 182 | dive_data['board_end_coords'] = board_end_coords 183 | 184 | return dive_data 185 | 186 | def aqa_metaprogram(frames, dive_data, progress=gr.Progress(), diveNum="", board_side=None, platform_detector=None, splash_detector=None, diver_detector=None, pose_model=None): 187 | print("AQA Metaprogram...") 188 | if len(frames) != len(dive_data['pose_pred']): 189 | raise gr.Error("Abstract Symbols first!") 190 | if diveNum != "": 191 | dive_num_given = True 192 | handstand, expected_som, expected_twists, back_facing, expected_direction, position = getDiveInfo_from_diveNum(diveNum) 193 | else: 194 | dive_num_given = False 195 | handstand, expected_som, expected_twists, expected_direction, dive_data = getDiveInfo_from_symbols(frames, dive_data=dive_data, platform_detector=platform_detector, 
splash_detector=splash_detector, diver_detector=diver_detector, pose_model=pose_model) 196 | 197 | if not dive_num_given: 198 | above_boards = dive_data['above_boards'] 199 | on_boards = dive_data['on_boards'] 200 | position_tightness = dive_data['position_tightness'] 201 | board_end_coords = dive_data['board_end_coords'] 202 | else: 203 | above_board = True 204 | on_board = True 205 | above_boards = [] 206 | on_boards = [] 207 | board_end_coords = [] 208 | position_tightness = [] 209 | splash = dive_data['splash'] 210 | diver_boxes = dive_data['diver_boxes'] 211 | board_side = dive_data['board_side'] 212 | pose_preds = dive_data['pose_pred'] 213 | takeoff = [] 214 | twist = [] 215 | som = [] 216 | entry = [] 217 | distance_from_board = [] 218 | feet_apart = [] 219 | over_under_rotation = [] 220 | som_counts = [] 221 | twist_counts = [] 222 | 223 | if platform_detector is None: 224 | platform_detector = get_platform_detector() 225 | if splash_detector is None: 226 | splash_detector = get_splash_detector() 227 | if diver_detector is None: 228 | diver_detector = get_diver_detector() 229 | if pose_model is None: 230 | pose_model = get_pose_model() 231 | 232 | prev_pred = None 233 | som_prev_pred = None 234 | half_som_count=0 235 | petal_count = 0 236 | in_petal = False 237 | num_frames = len(frames) 238 | for i in range(num_frames): 239 | progress(i/num_frames, desc="Calculating Dive Errors") 240 | pose_pred = pose_preds[i] 241 | calculated_half_som_count, skip = som_counter(pose_pred, prev_pose_pred=som_prev_pred, half_som_count=half_som_count, handstand=handstand) 242 | if not skip: 243 | som_prev_pred = pose_pred 244 | calculated_petal_count, calculated_in_petal = twist_counter(pose_pred, prev_pose_pred=prev_pred, in_petal=in_petal, petal_count=petal_count) 245 | if dive_num_given: 246 | outputs = platform_detector(frames[i]) 247 | board_end_coord = board_end(outputs, board_side=board_side) 248 | board_end_coords.append(board_end_coord) 249 | if above_board and not on_board and board_end_coord is not None and pose_pred is not None and np.array(pose_pred)[0][2][1] > int(board_end_coord[1]): 250 | above_board=False 251 | if on_board and detect_on_board(board_end_coord, board_side, pose_pred, handstand) is not None and not detect_on_board(board_end_coord, board_side, pose_pred, handstand): 252 | on_board = False 253 | if above_board: 254 | above_boards.append(1) 255 | else: 256 | above_boards.append(0) 257 | if on_board: 258 | on_boards.append(1) 259 | else: 260 | on_boards.append(0) 261 | else: 262 | board_end_coord = board_end_coords[i] 263 | above_board = (above_boards[i] == 1) 264 | on_board = (on_boards[i] == 1) 265 | calculated_takeoff = takeoff_microprogram_one_frame(filepath="", above_board=above_board, on_board=on_board, pose_pred=pose_pred) 266 | calculated_twist = twist_microprogram_one_frame(filepath="", on_board=on_board, pose_pred=pose_pred, expected_twists=expected_twists, petal_count=petal_count, expected_som=expected_som, half_som_count=half_som_count, diver_detector=diver_detector, pose_model=pose_model) 267 | calculated_som = somersault_microprogram_one_frame(filepath="", pose_pred=pose_pred, on_board=on_board, expected_som=expected_som, half_som_count=half_som_count, expected_twists=expected_twists, petal_count=petal_count, diver_detector=diver_detector, pose_model=pose_model) 268 | calculated_entry = entry_microprogram_one_frame(filepath="", frame=frames[i], above_board=above_board, on_board=on_board, pose_pred=pose_pred, expected_twists=expected_twists, 
petal_count=petal_count, expected_som=expected_som, half_som_count=half_som_count, splash_detector=splash_detector, visualize=False) 269 | if calculated_som == 1: 270 | half_som_count = calculated_half_som_count 271 | elif calculated_twist == 1: 272 | half_som_count = calculated_half_som_count 273 | petal_count = calculated_petal_count 274 | in_petal = calculated_in_petal 275 | # distance from board 276 | dist = calculate_distance_from_platform_for_one_frame(filepath="", im=frames[i], visualize=False, pose_pred=pose_pred, diver_detector=diver_detector, pose_model=pose_model, board_end_coord=board_end_coord, platform_detector=platform_detector) # saves photo to ./output/data/distance_from_board/ 277 | distance_from_board.append(dist) 278 | if dive_num_given: 279 | position_tightness.append(applyPositionTightnessError(filepath="", pose_pred=pose_pred, diver_detector=diver_detector, pose_model=pose_model)) 280 | feet_apart.append(applyFeetApartError(filepath="", pose_pred=pose_pred, diver_detector=diver_detector, pose_model=pose_model)) 281 | over_under_rotation.append(over_rotation(filepath="", pose_pred=pose_pred, diver_detector=diver_detector, pose_model=pose_model)) 282 | takeoff.append(calculated_takeoff) 283 | twist.append(calculated_twist) 284 | som.append(calculated_som) 285 | entry.append(calculated_entry) 286 | som_counts.append(half_som_count) 287 | twist_counts.append(petal_count) 288 | prev_pred = pose_pred 289 | 290 | dive_data['takeoff'] = takeoff 291 | dive_data['twist'] = twist 292 | dive_data['som'] = som 293 | dive_data['entry'] = entry 294 | dive_data['distance_from_board'] = distance_from_board 295 | dive_data['position_tightness'] = position_tightness 296 | dive_data['feet_apart'] = feet_apart 297 | dive_data['over_under_rotation'] = over_under_rotation 298 | dive_data['above_boards'] = above_boards 299 | dive_data['on_boards'] = on_boards 300 | dive_data['som_counts'] = som_counts 301 | dive_data['twist_counts'] = twist_counts 302 | dive_data['board_end_coords'] = board_end_coords 303 | dive_data['is_handstand'] = handstand 304 | dive_data['direction'] = expected_direction 305 | return dive_data 306 | 307 | if __name__ == '__main__': 308 | # Set up command-line arguments 309 | new_parser = argparse.ArgumentParser(description="Extract dive data to be used for scoring.") 310 | new_parser.add_argument("video_path", type=str, help="Path to dive video (mp4 format).") 311 | meta_program_args = new_parser.parse_args() 312 | 313 | video_path = meta_program_args.video_path 314 | frames = extract_frames(video_path) 315 | platform_detector = get_platform_detector() 316 | splash_detector = get_splash_detector() 317 | diver_detector = get_diver_detector() 318 | pose_model = get_pose_model() 319 | 320 | dive_data = abstractSymbols(frames, platform_detector=platform_detector, splash_detector=splash_detector, diver_detector=diver_detector, pose_model=pose_model) 321 | dive_data = aqa_metaprogram(frames, dive_data, platform_detector=platform_detector, splash_detector=splash_detector, diver_detector=diver_detector, pose_model=pose_model) 322 | 323 | save_path = "./output/{}.pkl".format("".join(video_path.split('.')[:-1])) 324 | with open(save_path, 'wb') as f: 325 | print("saving data into " + save_path) 326 | pickle.dump(dive_data, f) 327 | -------------------------------------------------------------------------------- /rule_based_programs/microprograms/dive_recognition_functions.py: -------------------------------------------------------------------------------- 1 | """ 2 | 
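The pickle written by the `__main__` block above can be reloaded later for scoring or report generation without re-running the detectors. A minimal sketch (the example path is hypothetical; per the save logic, the extension is dropped and the video's directory structure is kept under ./output/):
```
import pickle

# e.g. running the metaprogram on path/to/video.mp4 writes ./output/path/to/video.pkl
with open("./output/path/to/video.pkl", "rb") as f:
    dive_data = pickle.load(f)
print(sorted(dive_data.keys()))  # per-frame symbols plus derived fields such as
                                 # 'som_counts', 'twist_counts', and 'direction'
```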
dive_recognition_functions.py 3 | Author: Lauren Okamoto 4 | """ 5 | 6 | import pickle 7 | import numpy as np 8 | import os 9 | import math 10 | from scipy.signal import find_peaks 11 | from matplotlib import pyplot as plt 12 | 13 | def get_scale_factor(dive_data): 14 | # distance between thorax and pelvis 15 | distances = [] 16 | for pose_pred in dive_data['pose_pred']: 17 | if pose_pred is not None: 18 | distances.append(math.dist(pose_pred[0][6], pose_pred[0][7])) 19 | distances.sort() 20 | return np.median(distances) 21 | 22 | def find_angle(vector1, vector2): 23 | unit_vector_1 = vector1 / np.linalg.norm(vector1) 24 | unit_vector_2 = vector2 / np.linalg.norm(vector2) 25 | dot_product = np.dot(unit_vector_1, unit_vector_2) 26 | angle = math.degrees(np.arccos(dot_product)) 27 | return angle 28 | 29 | def is_back_facing(dive_data, board_side): 30 | directions = [] 31 | for i in range(len(dive_data['pose_pred'])): 32 | pose_pred = dive_data['pose_pred'][i] 33 | if pose_pred is None or dive_data['above_boards'][i] == 0: 34 | continue 35 | pose_pred = pose_pred[0] 36 | 37 | ## left knee bend ### 38 | l_knee = pose_pred[4] 39 | l_ankle = pose_pred[5] 40 | l_hip = pose_pred[3] 41 | l_knee_ankle = [l_ankle[0] - l_knee[0], 0-(l_ankle[1] - l_knee[1])] 42 | l_knee_hip = [l_hip[0] - l_knee[0], 0-(l_hip[1] - l_knee[1])] 43 | l_direction = rotation_direction(l_knee_hip, l_knee_ankle) 44 | 45 | ## right knee bend ### 46 | r_knee = pose_pred[1] 47 | r_ankle = pose_pred[0] 48 | r_hip = pose_pred[2] 49 | r_knee_ankle = [r_ankle[0] - r_knee[0], 0-(r_ankle[1] - r_knee[1])] 50 | r_knee_hip = [r_hip[0] - r_knee[0], 0-(r_hip[1] - r_knee[1])] 51 | r_direction = rotation_direction(r_knee_hip, r_knee_ankle) 52 | if l_direction == r_direction and l_direction != 0 and board_side == 'left': 53 | # we're looking for more clockwise 54 | return l_direction < 0 55 | elif l_direction == r_direction and l_direction != 0: 56 | # we're looking for more counterclockwise 57 | return l_direction > 0 58 | return False 59 | 60 | def rotation_direction(vector1, vector2, threshold=0.4): 61 | # Calculate the determinant to determine rotation direction 62 | determinant = vector1[0] * vector2[1] - vector1[1] * vector2[0] 63 | mag1= np.linalg.norm(vector1) 64 | mag2= np.linalg.norm(vector2) 65 | norm_det = determinant/(mag1*mag2) 66 | if norm_det > threshold: 67 | # "counterclockwise" 68 | return 1 69 | elif norm_det < 0-threshold: 70 | # "clockwise" 71 | return -1 72 | else: 73 | # "not determinent" 74 | return 0 75 | 76 | # returns None if either pose_pred or board_end_coord is None 77 | # returns True if diver is on board, returns False if diver is off board 78 | def detect_on_board(board_end_coord, board_side, pose_pred, handstand): 79 | if pose_pred is None: 80 | return 81 | if board_end_coord is None: 82 | return 83 | if board_side == 'left': 84 | # if right of board end 85 | if np.array(pose_pred)[0][2][0] > int(board_end_coord[0]): 86 | if handstand: 87 | distance = math.dist(np.array(pose_pred)[0][15], board_end_coord) < math.dist(np.array(pose_pred)[0][14], np.array(pose_pred)[0][15]) * 1.5 88 | else: 89 | distance = math.dist(np.array(pose_pred)[0][5], board_end_coord) < math.dist(np.array(pose_pred)[0][5], np.array(pose_pred)[0][4]) * 1.5 90 | 91 | return distance 92 | # if left of board end 93 | else: 94 | return True 95 | else: 96 | # if right of board end 97 | if np.array(pose_pred)[0][2][0] > int(board_end_coord[0]): 98 | return True 99 | # if left of board end 100 | else: 101 | if handstand: 102 | distance = 
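`rotation_direction` above classifies the turn from `vector1` to `vector2` by the sign of the normalized 2-D cross product (the determinant divided by the product of the magnitudes), with a dead zone of plus or minus `threshold`. A quick sanity check of the convention:
```
print(rotation_direction([1, 0], [0, 1]))    # 1  -> counterclockwise (normalized det = +1)
print(rotation_direction([0, 1], [1, 0]))    # -1 -> clockwise (normalized det = -1)
print(rotation_direction([1, 0], [1, 0.1]))  # 0  -> |normalized det| ~ 0.1, below the 0.4 threshold
```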
math.dist(np.array(pose_pred)[0][10], board_end_coord) < math.dist(np.array(pose_pred)[0][11], np.array(pose_pred)[0][10]) * 1.5 103 | else: 104 | distance = math.dist(np.array(pose_pred)[0][0], board_end_coord) < math.dist(np.array(pose_pred)[0][1], np.array(pose_pred)[0][0]) * 1.5 105 | return distance 106 | 107 | def find_position(dive_data): 108 | angles = [] 109 | three_in_a_row = 0 110 | for i in range(1, len(dive_data['pose_pred'])): 111 | pose_pred = dive_data['pose_pred'][i] 112 | if pose_pred is None or dive_data['som'][i]==0: 113 | continue 114 | pose_pred = pose_pred[0] 115 | l_knee = pose_pred[4] 116 | l_ankle = pose_pred[5] 117 | l_hip = pose_pred[3] 118 | l_knee_ankle = [l_ankle[0] - l_knee[0], 0-(l_ankle[1] - l_knee[1])] 119 | l_knee_hip = [l_hip[0] - l_knee[0], 0-(l_hip[1] - l_knee[1])] 120 | angle = find_angle(l_knee_ankle, l_knee_hip) 121 | angles.append(angle) 122 | if angle < 70: 123 | three_in_a_row += 1 124 | if three_in_a_row >=3: 125 | return 'tuck' 126 | else: 127 | three_in_a_row =0 128 | if twist_counter_full_dive(dive_data) > 0 and som_counter_full_dive(dive_data)[0] < 5: 129 | return 'free' 130 | return 'pike' 131 | 132 | 133 | def distance_point_to_line_segment(px, py, x1, y1, x2, y2): 134 | # Calculate the squared distance from point (px, py) to the line segment [(x1, y1), (x2, y2)] 135 | def sqr_distance_point_to_segment(): 136 | line_length_sq = (x2 - x1)**2 + (y2 - y1)**2 137 | if line_length_sq == 0: 138 | return (px - x1)**2 + (py - y1)**2 139 | t = max(0, min(1, ((px - x1) * (x2 - x1) + (py - y1) * (y2 - y1)) / line_length_sq)) 140 | return ((px - (x1 + t * (x2 - x1)))**2 + (py - (y1 + t * (y2 - y1)))**2) 141 | 142 | # Calculate the closest point on the line segment to the given point (px, py) 143 | def closest_point_on_line_segment(): 144 | line_length_sq = (x2 - x1)**2 + (y2 - y1)**2 145 | if line_length_sq == 0: 146 | return x1, y1 147 | t = max(0, min(1, ((px - x1) * (x2 - x1) + (py - y1) * (y2 - y1)) / line_length_sq)) 148 | closest_x = x1 + t * (x2 - x1) 149 | closest_y = y1 + t * (y2 - y1) 150 | return closest_x, closest_y 151 | 152 | closest_point = closest_point_on_line_segment() 153 | distance = math.sqrt(sqr_distance_point_to_segment()) 154 | 155 | return closest_point, distance 156 | 157 | def min_distance_from_line_to_circle(line_start, line_end, circle_center, circle_radius): 158 | closest_point, distance = distance_point_to_line_segment(circle_center[0], circle_center[1], 159 | line_start[0], line_start[1], 160 | line_end[0], line_end[1]) 161 | 162 | min_distance = max(0, distance - circle_radius) 163 | return min_distance 164 | 165 | def twister(pose_pred, prev_pose_pred=None, in_petal=False, petal_count=0, outer=10, inner=9, valid=17, middle=0.5): 166 | if pose_pred is None: 167 | return petal_count, in_petal 168 | min_dist = 0 169 | pose_pred = pose_pred[0] 170 | vector1 = [pose_pred[2][0] - pose_pred[3][0], 0-(pose_pred[2][1] - pose_pred[3][1])] 171 | if prev_pose_pred is not None: 172 | prev_pose_pred = prev_pose_pred[0] 173 | prev_pose_pred = [prev_pose_pred[2][0] - prev_pose_pred[3][0], 0-(prev_pose_pred[2][1] - prev_pose_pred[3][1])] 174 | min_dist = min_distance_from_line_to_circle(prev_pose_pred, vector1, (0, 0), middle) 175 | if np.linalg.norm(vector1) > valid: 176 | return petal_count, in_petal 177 | if min_dist is not None and in_petal and np.linalg.norm(vector1) > outer and min_dist == 0: 178 | petal_count += 1 179 | elif not in_petal and np.linalg.norm(vector1) > outer: 180 | in_petal = True 181 | petal_count += 1 182 | 
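`distance_point_to_line_segment` above clamps the projection parameter t to [0, 1], so points that project beyond an endpoint are measured against that endpoint, and `min_distance_from_line_to_circle` simply subtracts the circle radius (floored at zero). A worked example:
```
# closest point of the segment (-1,0)-(1,0) to the point (0,1) is (0,0), at distance 1
print(distance_point_to_line_segment(0, 1, -1, 0, 1, 0))               # ((0.0, 0.0), 1.0)
# the same segment stays 0.5 away from a circle of radius 0.5 centered at (0,1)
print(min_distance_from_line_to_circle((-1, 0), (1, 0), (0, 1), 0.5))  # 0.5
```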
elif in_petal and np.linalg.norm(vector1) < inner: 183 | in_petal = False 184 | return petal_count, in_petal 185 | 186 | 187 | def twist_counter(pose_pred, prev_pose_pred=None, in_petal=False, petal_count=0): 188 | valid = 17 189 | outer = 10 190 | inner = 9 191 | if pose_pred is None: 192 | return petal_count, in_petal 193 | min_dist = 0 194 | pose_pred = pose_pred[0] 195 | vector1 = [pose_pred[2][0] - pose_pred[3][0], 0-(pose_pred[2][1] - pose_pred[3][1])] 196 | if prev_pose_pred is not None: 197 | prev_pose_pred = prev_pose_pred[0] 198 | prev_pose_pred = [prev_pose_pred[2][0] - prev_pose_pred[3][0], 0-(prev_pose_pred[2][1] - prev_pose_pred[3][1])] 199 | min_dist = min_distance_from_line_to_circle(prev_pose_pred, vector1, (0, 0), 0.5) 200 | if np.linalg.norm(vector1) > valid: 201 | return petal_count, in_petal 202 | if min_dist is not None and in_petal and np.linalg.norm(vector1) > outer and min_dist == 0: 203 | petal_count += 1 204 | elif not in_petal and np.linalg.norm(vector1) > outer: 205 | in_petal = True 206 | elif in_petal and np.linalg.norm(vector1) < inner: 207 | in_petal = False 208 | petal_count += 1 209 | return petal_count, in_petal 210 | 211 | 212 | def twist_counter_full_dive(dive_data, visualize=False): 213 | dist_hip = [] 214 | prev_pose_pred = None 215 | in_petal=False 216 | petal_count=0 217 | scale = get_scale_factor(dive_data) 218 | valid = scale / 1.5 219 | outer = scale / 3.2 220 | inner = scale / 3.4 221 | middle = 0.5 222 | next_next_pose_pred = dive_data['pose_pred'][4] 223 | for i in range(len(dive_data['pose_pred'])): 224 | pose_pred = dive_data['pose_pred'][i] 225 | if i < len(dive_data['pose_pred']) - 1: 226 | next_pose_pred = dive_data['pose_pred'][i + 1] 227 | if i < len(dive_data['pose_pred']) - 4 and next_next_pose_pred is not None: 228 | next_next_pose_pred = dive_data['pose_pred'][i + 4] 229 | if pose_pred is None or dive_data['on_boards'][i] == 1 or dive_data['position_tightness'][i] <= 80 or next_next_pose_pred is None: 230 | continue 231 | petal_count, in_petal = twister(pose_pred, prev_pose_pred=prev_pose_pred, in_petal=in_petal, petal_count=petal_count, outer=outer, inner=inner, middle=middle, valid=valid) 232 | prev_pose_pred = pose_pred 233 | if visualize: 234 | pose_pred = pose_pred[0] 235 | dist_hip.append([pose_pred[2][0] - pose_pred[3][0], 0-(pose_pred[2][1] - pose_pred[3][1])]) 236 | if visualize: 237 | dist_hip = np.array(dist_hip) 238 | plt.plot(dist_hip[:, 0], dist_hip[:, 1], label="right-to-left hip") 239 | circle1 = plt.Circle((0, 0), outer, fill=False) 240 | plt.gca().add_patch(circle1) 241 | circle2 = plt.Circle((0, 0), inner, fill=False) 242 | plt.gca().add_patch(circle2) 243 | circle3 = plt.Circle((0, 0), valid, fill=False) 244 | plt.gca().add_patch(circle3) 245 | plt.legend() 246 | plt.show() 247 | return petal_count 248 | 249 | def rotation_direction_som(vector1, vector2, threshold=0.4): 250 | # Calculate the determinant to determine rotation direction 251 | determinant = vector1[0] * vector2[1] - vector1[1] * vector2[0] 252 | mag1= np.linalg.norm(vector1) 253 | mag2= np.linalg.norm(vector2) 254 | norm_det = determinant/(mag1*mag2) 255 | theta = np.arcsin(norm_det) 256 | return math.degrees(theta) 257 | def is_handstand(dive_data): 258 | first_frame_pose_pred = dive_data['pose_pred'][0] 259 | handstand = False 260 | if first_frame_pose_pred[0][6][1] < first_frame_pose_pred[0][7][1]: 261 | handstand = True 262 | return handstand 263 | 264 | 265 | def som_counter(pose_pred=None, prev_pose_pred=None, half_som_count=0, 
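`twist_counter` (and its scale-adaptive variant `twister` used by `twist_counter_full_dive`) counts half-twists by watching the projected right-to-left hip vector trace "petals": the vector leaves an outer circle while the diver is side-on and collapses through an inner circle each time the hips face the camera. A synthetic demo under simplifying assumptions -- only hip keypoints 2 and 3 are populated, and `fake_pose` is a hypothetical helper:
```
import numpy as np

def fake_pose(sep):
    p = np.zeros((1, 16, 2))
    p[0, 2] = [sep / 2.0, 0.0]   # right hip
    p[0, 3] = [-sep / 2.0, 0.0]  # left hip
    return p

petal_count, in_petal, prev = 0, False, None
for t in np.linspace(0.0, 2.0 * np.pi, 60):   # one full twist about the vertical axis
    pose = fake_pose(14.0 * np.cos(t))        # projected hip separation oscillates
    petal_count, in_petal = twist_counter(pose, prev_pose_pred=prev,
                                          in_petal=in_petal, petal_count=petal_count)
    prev = pose
print(petal_count)  # 2 petals -> two half-twists, i.e. one full twist
```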
handstand=False): 266 | if pose_pred is None: 267 | return half_som_count, True 268 | pose_pred = pose_pred[0] 269 | vector1 = [pose_pred[7][0] - pose_pred[6][0], 0-(pose_pred[7][1] - pose_pred[6][1])] # flip y axis 270 | if (not handstand and half_som_count % 2 == 0) or (handstand and half_som_count %2 == 1): 271 | vector2 = [0, -1] 272 | else: 273 | vector2 = [0, 1] 274 | unit_vector_1 = vector1 / np.linalg.norm(vector1) 275 | unit_vector_2 = vector2 / np.linalg.norm(vector2) 276 | dot_product = np.dot(unit_vector_1, unit_vector_2) 277 | current_angle = math.degrees(np.arccos(dot_product)) 278 | if prev_pose_pred is not None: 279 | prev_pose_pred = prev_pose_pred[0] 280 | prev_vector = [prev_pose_pred[7][0] - prev_pose_pred[6][0], 0-(prev_pose_pred[7][1] - prev_pose_pred[6][1])] # flip y axis 281 | prev_unit_vector = prev_vector / np.linalg.norm(prev_vector) 282 | prev_angle_diff = math.degrees(np.arccos(np.dot(unit_vector_1, prev_unit_vector))) 283 | if prev_angle_diff > 115: 284 | return half_som_count, True 285 | if current_angle <= 80: 286 | half_som_count += 1 287 | return half_som_count, False 288 | 289 | 290 | def som_counter_full_dive(dive_data, visualize=False): 291 | half_som_count = 0 292 | dist_body = [] 293 | handstand = is_handstand(dive_data) 294 | next_next_pose_pred = dive_data['pose_pred'][2] 295 | prev = None 296 | for i in range(len(dive_data['pose_pred'])): 297 | pose_pred = dive_data['pose_pred'][i] 298 | if i < len(dive_data['pose_pred']) - 2 and next_next_pose_pred is not None: 299 | next_next_pose_pred = dive_data['pose_pred'][i + 2] 300 | if pose_pred is None or next_next_pose_pred is None or dive_data['on_boards'][i] == 1: 301 | continue 302 | pose_pred = pose_pred[0] 303 | vector1 = [pose_pred[7][0] - pose_pred[6][0], 0-(pose_pred[7][1] - pose_pred[6][1])] 304 | if (not handstand and half_som_count % 2 == 0) or (handstand and half_som_count % 2 == 1): 305 | vector2 = [0, -1] 306 | else: 307 | vector2 = [0, 1] 308 | sensitivity = 115 309 | if prev is not None and find_angle(vector1, prev) > sensitivity: 310 | continue 311 | is_clockwise = is_rotating_clockwise(dive_data) 312 | if prev is not None and ((is_clockwise and rotation_direction_som(vector1, prev)<0) or (not is_clockwise and rotation_direction_som(vector1, prev)>0)): 313 | continue 314 | angle = find_angle(vector1, vector2) 315 | if angle <= 75: 316 | half_som_count += 1 317 | if visualize: 318 | dist_body.append([pose_pred[7][0] - pose_pred[6][0], 0-(pose_pred[7][1] - pose_pred[6][1])]) 319 | prev = vector1 320 | if visualize: 321 | dist_body = np.array(dist_body) 322 | plt.plot(dist_body[:, 0], dist_body[:, 1], label="pelvis-to-thorax") 323 | plt.xlabel("x-coord") 324 | plt.ylabel("y-coord") 325 | plt.legend() 326 | plt.show() 327 | return half_som_count, handstand 328 | 329 | def getDiveInfo(diveNum): 330 | handstand = (diveNum[0] == '6') 331 | expected_som = int(diveNum[2]) 332 | if len(diveNum) == 5: 333 | expected_twists = int(diveNum[3]) 334 | else: 335 | expected_twists = 0 336 | if diveNum[0] == '1' or diveNum[0] == '3' or diveNum[:2] == '51' or diveNum[:2] == '53' or diveNum[:2] == '61' or diveNum[:2] == '63': 337 | back_facing = False 338 | else: 339 | back_facing = True 340 | if diveNum[0] == '1' or diveNum[:2] == '51' or diveNum[:2] == '61': 341 | expected_direction = 'front' 342 | elif diveNum[0] == '2' or diveNum[:2] == '52' or diveNum[:2] == '62': 343 | expected_direction = 'back' 344 | elif diveNum[0] == '3' or diveNum[:2] == '53' or diveNum[:2] == '63': 345 | expected_direction = 
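`som_counter` counts half-somersaults from the pelvis-to-thorax vector (keypoints 6 and 7): the reference direction alternates between straight down and straight up, a half-somersault is banked whenever the trunk comes within the angular window of the current reference, and sudden jumps above 115 degrees are treated as pose-estimation flips and skipped. A synthetic demo under the same hypothetical-keypoint assumption as above:
```
import math
import numpy as np

def fake_trunk(theta_deg):
    p = np.zeros((1, 16, 2))
    p[0, 6] = [0.0, 0.0]  # pelvis at the origin
    # image y grows downward, so store -cos: after the y-flip the trunk vector is [sin, cos]
    p[0, 7] = [math.sin(math.radians(theta_deg)), -math.cos(math.radians(theta_deg))]
    return p

half_som_count, prev = 0, None
for theta in range(0, 541, 10):  # upright start, 540 degrees of rotation = 1.5 somersaults
    pose = fake_trunk(theta)
    half_som_count, skip = som_counter(pose, prev_pose_pred=prev,
                                       half_som_count=half_som_count, handstand=False)
    if not skip:
        prev = pose
print(half_som_count)  # 3 half-somersaults, a 1.5-somersault dive
```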
'reverse' 346 | elif diveNum[0] == '4': 347 | expected_direction = 'inward' 348 | if diveNum[-1] == 'b': 349 | position = 'pike' 350 | elif diveNum[-1] == 'c': 351 | position = 'tuck' 352 | else: 353 | position = 'free' 354 | return handstand, expected_som, expected_twists, back_facing, expected_direction, position 355 | 356 | def get_direction(dive_data): 357 | clockwise = is_rotating_clockwise(dive_data) 358 | board_side = dive_data['board_side'] 359 | if board_side == "right": 360 | back_facing = is_back_facing(dive_data, 'right') 361 | if back_facing and clockwise: 362 | direction = 'inward' 363 | elif back_facing and not clockwise: 364 | direction = 'back' 365 | elif not back_facing and clockwise: 366 | direction = 'reverse' 367 | elif not back_facing and not clockwise: 368 | direction = 'front' 369 | else: 370 | back_facing = is_back_facing(dive_data, 'left') 371 | if back_facing and clockwise: 372 | direction = 'back' 373 | elif back_facing and not clockwise: 374 | direction = 'inward' 375 | elif not back_facing and clockwise: 376 | direction = 'front' 377 | elif not back_facing and not clockwise: 378 | direction = 'reverse' 379 | return direction 380 | 381 | def is_rotating_clockwise(dive_data): 382 | directions = [] 383 | for i in range(1, len(dive_data['pose_pred'])): 384 | if dive_data['pose_pred'][i] is None or dive_data['pose_pred'][i-1] is None: 385 | continue 386 | if dive_data['on_boards'][i] == 0: 387 | prev_pose_pred_hip = dive_data['pose_pred'][i-1][0][3] 388 | curr_pose_pred_hip = dive_data['pose_pred'][i][0][3] 389 | prev_pose_pred_knee = dive_data['pose_pred'][i-1][0][4] 390 | curr_pose_pred_knee = dive_data['pose_pred'][i][0][4] 391 | prev_hip_knee = [prev_pose_pred_knee[0] - prev_pose_pred_hip[0], 0-(prev_pose_pred_knee[1] - prev_pose_pred_hip[1])] 392 | curr_hip_knee = [curr_pose_pred_knee[0] - curr_pose_pred_hip[0], 0-(curr_pose_pred_knee[1] - curr_pose_pred_hip[1])] 393 | direction = rotation_direction(prev_hip_knee, curr_hip_knee, threshold=0) 394 | directions.append(direction) 395 | return np.sum(directions) < 0 396 | -------------------------------------------------------------------------------- /rule_based_programs/scoring_functions.py: -------------------------------------------------------------------------------- 1 | """ 2 | scoring_functions.py 3 | Author: Lauren Okamoto 4 | """ 5 | 6 | import pickle 7 | import numpy as np 8 | import os 9 | import math 10 | from scipy.signal import find_peaks 11 | from rule_based_programs.microprograms.dive_recognition_functions import * 12 | 13 | ### All functions (excluding helper functions) take "dive_data" as input, 14 | ### which is the dictionary with all the information outputted by the AQA metaprogram 15 | 16 | ############## HELPER FUNCTIONS ###################################### 17 | def rotation_direction(vector1, vector2, threshold=0.4): 18 | # Calculate the determinant to determine rotation direction 19 | determinant = vector1[0] * vector2[1] - vector1[1] * vector2[0] 20 | mag1= np.linalg.norm(vector1) 21 | mag2= np.linalg.norm(vector2) 22 | norm_det = determinant/(mag1*mag2) 23 | if norm_det > threshold: 24 | # "counterclockwise" 25 | return 1 26 | elif norm_det < 0-threshold: 27 | # "clockwise" 28 | return -1 29 | else: 30 | # "not determinent" 31 | return 0 32 | 33 | def find_angle(vector1, vector2): 34 | unit_vector_1 = vector1 / np.linalg.norm(vector1) 35 | unit_vector_2 = vector2 / np.linalg.norm(vector2) 36 | dot_product = np.dot(unit_vector_1, unit_vector_2) 37 | angle = 
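`get_direction` above combines three symbols (board side, whether the diver takes off back-facing, and the rotation sense) into the dive group; its branches are equivalent to this lookup table, mirrored between the two board sides:
```
# (board_side, back_facing, clockwise) -> direction, derived from get_direction's branches
DIRECTION = {
    ('right', True,  True):  'inward',   ('left', True,  True):  'back',
    ('right', True,  False): 'back',     ('left', True,  False): 'inward',
    ('right', False, True):  'reverse',  ('left', False, True):  'front',
    ('right', False, False): 'front',    ('left', False, False): 'reverse',
}
```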
math.degrees(np.arccos(dot_product)) 38 | return angle 39 | 40 | ################################################################# 41 | 42 | def height_off_board_score(dive_data): 43 | above_board_indices = [i for i in range(0, len(dive_data['distance_from_board'])) if dive_data['above_boards'][i]==1] 44 | takeoff_indices = [i for i in range(0, len(dive_data['takeoff'])) if dive_data['takeoff'][i]==1] 45 | final_indices = [] 46 | prev_board_end_coord = None 47 | for i in range(len(above_board_indices)): 48 | board_end_coord = dive_data['board_end_coords'][i] 49 | if board_end_coord is not None and board_end_coord[1] < 30: 50 | continue 51 | if board_end_coord is not None and prev_board_end_coord is not None and math.dist(board_end_coord, prev_board_end_coord) > 150: 52 | continue 53 | if above_board_indices[i] not in takeoff_indices: 54 | final_indices.append(above_board_indices[i]) 55 | prev_board_end_coord = board_end_coord 56 | 57 | heights = [] 58 | for i in range(len(final_indices)): 59 | pose_pred = dive_data['pose_pred'][final_indices[i]] 60 | board_end_coord = dive_data['board_end_coords'][final_indices[i]] 61 | if pose_pred is None or board_end_coord is None: 62 | continue 63 | pose_pred = pose_pred[0] 64 | min_height = float('inf') 65 | for j in range(len(pose_pred)): 66 | if board_end_coord[1] - pose_pred[j][1]< min_height: 67 | min_height = board_end_coord[1] - pose_pred[j][1] 68 | if min_height < 0: 69 | min_height = 0 70 | heights.append(min_height) 71 | if len(heights) == 0: 72 | return None, None 73 | max_scaled_height = max(heights) / get_scale_factor(dive_data) 74 | return max_scaled_height, final_indices[np.argmax(heights)] 75 | 76 | def distance_from_board_score(dive_data): 77 | above_board_indices = [i for i in range(0, len(dive_data['distance_from_board'])) if dive_data['above_boards'][i]==1] 78 | takeoff_indices = [i for i in range(0, len(dive_data['takeoff'])) if dive_data['takeoff'][i]==1] 79 | final_indices = [] 80 | for i in range(len(above_board_indices)): 81 | if above_board_indices[i] not in takeoff_indices: 82 | final_indices.append(above_board_indices[i]) 83 | dists = np.array(dive_data['distance_from_board'])[final_indices] 84 | for i in range(len(dists)): 85 | if dists[i] is None: 86 | dists[i] = float('inf') 87 | min_scaled_dist = np.min(dists) / get_scale_factor(dive_data) 88 | too_close_threshold = 0.25 89 | if 'diveNum' in dive_data: 90 | if dive_data['diveNum'][0] == '4': 91 | too_far_threshold = 1.1 92 | if dive_data['diveNum'][0] == '1': 93 | too_far_threshold = 1.6 94 | if dive_data['diveNum'][0] == '2': 95 | too_far_threshold = 1.8 96 | if dive_data['diveNum'][0] == '3': 97 | too_far_threshold = 1.6 98 | if dive_data['diveNum'][0] == '5': 99 | too_far_threshold = 1.5 100 | if dive_data['diveNum'][0] == '6': 101 | too_far_threshold = 1.1 102 | else: 103 | too_far_threshold = 1.8 104 | # good distance 105 | if min_scaled_dist < too_far_threshold and min_scaled_dist > too_close_threshold: 106 | return 0, min_scaled_dist, final_indices[np.argmin(dists)] 107 | # too far 108 | if min_scaled_dist >= too_far_threshold: 109 | return 1, min_scaled_dist, final_indices[np.argmin(dists)] 110 | # too close 111 | if min_scaled_dist <= too_close_threshold: 112 | return -1, min_scaled_dist, final_indices[np.argmin(dists)] 113 | return min_scaled_dist 114 | 115 | def knee_bend_score(dive_data): 116 | if find_position(dive_data) == 'tuck': 117 | return None, None 118 | knee_bends = [] 119 | for i in range(len(dive_data['pose_pred'])): 120 | if dive_data['som'][i] == 
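The height and distance scores above divide pixel measurements by `get_scale_factor` (the median thorax-to-pelvis distance over the dive), so thresholds like 0.25 and 1.8 are expressed in torso lengths rather than pixels and are resolution-independent. A sketch with hypothetical numbers:
```
# a 40 px minimum gap from the platform with a 32 px median torso length
torso_px = 32.0                      # would come from get_scale_factor(dive_data)
min_scaled_dist = 40.0 / torso_px    # 1.25 torso lengths
print(0.25 < min_scaled_dist < 1.8)  # True -> neither too close nor too far (default thresholds)
```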
0: 121 | continue 122 | pose_pred = dive_data['pose_pred'][i] 123 | if pose_pred is None: 124 | continue 125 | pose_pred = pose_pred[0] 126 | knee_to_ankle = [pose_pred[1][0] - pose_pred[0][0], 0-(pose_pred[1][1]-pose_pred[0][1])] 127 | knee_to_hip = [pose_pred[1][0] - pose_pred[2][0], 0-(pose_pred[1][1]-pose_pred[2][1])] 128 | knee_bend = find_angle(knee_to_ankle, knee_to_hip) 129 | knee_bends.append(knee_bend) 130 | if len(knee_bends) == 0: 131 | return None, None 132 | som_indices = [i for i in range(0, len(dive_data['som'])) if dive_data['som'][i]==1] 133 | som_avg_knee_bend = np.mean(knee_bends) 134 | return 180 - som_avg_knee_bend, som_indices 135 | 136 | def position_tightness_score(dive_data): 137 | som_indices = [i for i in range(0, len(dive_data['som'])) if dive_data['som'][i]==1] 138 | twist_indices = [i for i in range(0, len(dive_data['som'])) if dive_data['twist'][i]==1] 139 | som_tightness = np.array(dive_data['position_tightness'])[som_indices] 140 | twist_tightness = 180 - np.array(dive_data['position_tightness'])[twist_indices] 141 | 142 | # Compute the area using the composite trapezoidal rule. 143 | som_tightness = np.array(list(filter(lambda item: item is not None and item < 90, som_tightness))) 144 | twist_tightness = np.array(list(filter(lambda item: item is not None and item < 90, twist_tightness))) 145 | if len(som_indices) == 0: 146 | som_avg = None 147 | else: 148 | som_avg = np.mean(som_tightness) 149 | if len(twist_indices)==0: 150 | return som_avg, None, som_indices, twist_indices 151 | twist_avg = np.mean(twist_tightness) 152 | if som_avg is not None: 153 | som_avg -= 15 154 | return som_avg, twist_avg, som_indices, twist_indices 155 | 156 | def is_rotating_clockwise(dive_data): 157 | directions = [] 158 | for i in range(1, len(dive_data['pose_pred'])): 159 | if dive_data['pose_pred'][i] is None or dive_data['pose_pred'][i-1] is None: 160 | continue 161 | if dive_data['on_boards'][i] == 0: 162 | prev_pose_pred_hip = dive_data['pose_pred'][i-1][0][3] 163 | curr_pose_pred_hip = dive_data['pose_pred'][i][0][3] 164 | prev_pose_pred_knee = dive_data['pose_pred'][i-1][0][4] 165 | curr_pose_pred_knee = dive_data['pose_pred'][i][0][4] 166 | prev_hip_knee = [prev_pose_pred_knee[0] - prev_pose_pred_hip[0], 0-(prev_pose_pred_knee[1] - prev_pose_pred_hip[1])] 167 | curr_hip_knee = [curr_pose_pred_knee[0] - curr_pose_pred_hip[0], 0-(curr_pose_pred_knee[1] - curr_pose_pred_hip[1])] 168 | direction = rotation_direction(prev_hip_knee, curr_hip_knee, threshold=0) 169 | directions.append(direction) 170 | return np.sum(directions) < 0 171 | 172 | def over_under_rotation_score(dive_data): 173 | entry_indices = [i for i in range(0, len(dive_data['entry'])) if dive_data['entry'][i]==1] 174 | over_under_rotation_error = np.array(dive_data['over_under_rotation'])[entry_indices] 175 | splashes = np.array(dive_data['splash'])[entry_indices] 176 | for i in range(len(over_under_rotation_error) - 1, -1, -1): 177 | if over_under_rotation_error[i] is None: 178 | continue 179 | else: 180 | # gets the second to last pose (assuming the last pose has incorrect pose estimation) 181 | index = i-2 182 | if index < 0: 183 | index = 0 184 | total_index = entry_indices[index] 185 | if splashes[index] is None and over_under_rotation_error[index] is not None: 186 | pose_pred = dive_data['pose_pred'][total_index][0] 187 | thorax_pelvis_vector = [pose_pred[1][0] - pose_pred[7][0], 0-(pose_pred[1][1]-pose_pred[7][1])] 188 | prev_pose_pred = dive_data['pose_pred'][total_index - 1] 189 | if prev_pose_pred 
is not None: 190 | prev_pose_pred = prev_pose_pred[0] 191 | prev_thorax_pelvis_vector = [prev_pose_pred[1][0] - prev_pose_pred[7][0], 0-(prev_pose_pred[1][1]-prev_pose_pred[7][1])] 192 | rotation_speed = find_angle(thorax_pelvis_vector, prev_thorax_pelvis_vector) 193 | else: 194 | rotation_speed = 10 195 | vector2 = [0, 1] 196 | if is_rotating_clockwise(dive_data): 197 | # if under-rotated 198 | if thorax_pelvis_vector[0] < 0: 199 | avg_leg_torso = find_angle(thorax_pelvis_vector, vector2) - rotation_speed 200 | 201 | else: 202 | avg_leg_torso = find_angle(thorax_pelvis_vector, vector2) + rotation_speed 203 | else: 204 | # if over-rotated 205 | if thorax_pelvis_vector[0] < 0: 206 | avg_leg_torso = find_angle(thorax_pelvis_vector, vector2) + rotation_speed 207 | else: 208 | avg_leg_torso = find_angle(thorax_pelvis_vector, vector2) - rotation_speed 209 | return np.abs(avg_leg_torso), entry_indices[index] 210 | break 211 | 212 | def straightness_during_entry_score(dive_data): 213 | entry_indices = [i for i in range(0, len(dive_data['entry'])) if dive_data['entry'][i]==1] 214 | straightness_during_entry = np.array(dive_data['position_tightness'])[entry_indices] 215 | over_under_rotation = over_under_rotation_score(dive_data) 216 | if over_under_rotation is not None: 217 | frame = over_under_rotation[1] 218 | index = entry_indices.index(frame) - 1 219 | if index < 0: 220 | index = 0 221 | return 180-straightness_during_entry[index], [frame-1, frame, frame + 1] 222 | splashes = np.array(dive_data['splash'])[entry_indices] 223 | for i in range(len(straightness_during_entry) - 1, -1, -1): 224 | if i > 0 and (straightness_during_entry[i] is None or splashes[i] is not None): 225 | continue 226 | else: 227 | # gets the second to last pose (assuming the last pose has incorrect pose estimation) 228 | if straightness_during_entry[i] is not None: 229 | if 180-straightness_during_entry[i] > 130: 230 | continue 231 | return 180-straightness_during_entry[i], entry_indices[i-1:i+2] 232 | break 233 | 234 | def splash_score(dive_data): 235 | entry_indices = [i for i in range(0, len(dive_data['entry'])) if dive_data['entry'][i]==1] 236 | if len(entry_indices) == 0: 237 | return None 238 | splash_indices=[i for i in range(0, len(dive_data['splash'])) if dive_data['splash'][i] is not None] 239 | splashes = np.array(dive_data['splash'])[entry_indices] 240 | for i in range(len(splashes)): 241 | if splashes[i] is None: 242 | splashes[i] = 0 243 | splashes = splashes / get_scale_factor(dive_data)**2 244 | # area under curve 245 | area = np.trapz(splashes, dx=5) 246 | return area, entry_indices[np.argmax(splashes)], splash_indices 247 | 248 | # feet apart 249 | def feet_apart_score(dive_data): 250 | takeoff_indices = [i for i in range(0, len(dive_data['takeoff'])) if dive_data['takeoff'][i]==1] 251 | non_takeoff_indices = [i for i in range(len(dive_data['takeoff'])) if (i not in takeoff_indices and dive_data['splash'][i] is None)] 252 | feet_apart_error = np.array(dive_data['feet_apart'])[non_takeoff_indices] 253 | for i in range(len(feet_apart_error)): 254 | if feet_apart_error[i] is None or math.isnan(feet_apart_error[i]): 255 | feet_apart_error[i] = 0 256 | peaks, _ = find_peaks(feet_apart_error, height=5) 257 | if len(peaks) >= 1: 258 | peak_indices = np.array(non_takeoff_indices)[peaks] 259 | else: 260 | peak_indices = [] 261 | area = np.mean(feet_apart_error) 262 | 263 | return area, peak_indices 264 | 265 | def find_position(dive_data): 266 | angles = [] 267 | three_in_a_row = 0 268 | for i in range(1, 
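`splash_score` above integrates the normalized per-frame splash area over the entry window with the composite trapezoidal rule (`np.trapz` with dx=5), so a splash that stays large across many frames is penalized more than a brief spike of the same peak size. For example:
```
import numpy as np
# 5 * (0/2 + 2 + 4 + 3 + 1/2) = 47.5
print(np.trapz([0.0, 2.0, 4.0, 3.0, 1.0], dx=5))  # 47.5
```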
len(dive_data['pose_pred'])): 269 | pose_pred = dive_data['pose_pred'][i] 270 | if pose_pred is None or dive_data['som'][i]==0: 271 | continue 272 | pose_pred = pose_pred[0] 273 | l_knee = pose_pred[4] 274 | l_ankle = pose_pred[5] 275 | l_hip = pose_pred[3] 276 | l_knee_ankle = [l_ankle[0] - l_knee[0], 0-(l_ankle[1] - l_knee[1])] 277 | l_knee_hip = [l_hip[0] - l_knee[0], 0-(l_hip[1] - l_knee[1])] 278 | angle = find_angle(l_knee_ankle, l_knee_hip) 279 | angles.append(angle) 280 | # print(angle) 281 | if angle < 70: 282 | three_in_a_row += 1 283 | if three_in_a_row >=3: 284 | return 'tuck' 285 | else: 286 | three_in_a_row =0 287 | return 'pike' 288 | 289 | def get_position_from_diveNum(dive_data): 290 | diveNum = dive_data['diveNum'] 291 | position_code = diveNum[-1] 292 | if position_code == 'a': 293 | return "straight" 294 | elif position_code == 'b': 295 | return "pike" 296 | elif position_code == 'c': 297 | return "tuck" 298 | elif position_code == 'd': 299 | return "free" 300 | else: 301 | return None 302 | 303 | def get_all_report_scores(dive_data): 304 | with open('distribution_data.pkl', 'rb') as f: 305 | distribution_data = pickle.load(f) 306 | ## handstand and som_count## 307 | expected_som, handstand = som_counter_full_dive(dive_data) 308 | ## twist_count 309 | expected_twists = twist_counter_full_dive(dive_data) 310 | ## direction: front, back, reverse, inward 311 | expected_direction = get_direction(dive_data) 312 | dive_data['is_handstand'] = handstand 313 | dive_data['direction'] = expected_direction 314 | 315 | intermediate_scores = {} 316 | all_percentiles = [] 317 | entry_indices = [i for i in range(0, len(dive_data['entry'])) if dive_data['entry'][i]==1] 318 | 319 | ### height off board ### 320 | if dive_data['is_handstand']: 321 | error_scores = distribution_data['armstand_height_off_board_scores'] 322 | elif expected_twists >0: 323 | error_scores = distribution_data['twist_height_off_board_scores'] 324 | elif dive_data['direction']=='front': 325 | error_scores = distribution_data['front_height_off_board_scores'] 326 | elif dive_data['direction']=='back': 327 | error_scores = distribution_data['back_height_off_board_scores'] 328 | elif dive_data['direction']=='reverse': 329 | error_scores = distribution_data['reverse_height_off_board_scores'] 330 | elif dive_data['direction']=='inward': 331 | error_scores = distribution_data['inward_height_off_board_scores'] 332 | error_scores = list(filter(lambda item: item is not None, error_scores)) 333 | intermediate_scores['height_off_board'] = {} 334 | if dive_data['is_handstand']: 335 | intermediate_scores['height_off_board']['raw_score'] = None 336 | intermediate_scores['height_off_board']['frame_index'] = None 337 | else: 338 | intermediate_scores['height_off_board']['raw_score'] = height_off_board_score(dive_data)[0] 339 | intermediate_scores['height_off_board']['frame_index'] = height_off_board_score(dive_data)[1] 340 | err = intermediate_scores['height_off_board']['raw_score'] 341 | if err is not None: 342 | temp = error_scores 343 | temp.append(err) 344 | temp.sort() 345 | intermediate_scores['height_off_board']['percentile'] = temp.index(err)/len(temp) 346 | all_percentiles.append(temp.index(err)/len(temp)) 347 | else: 348 | intermediate_scores['height_off_board']['percentile'] = None 349 | 350 | ## distance from board #### 351 | error_scores = distribution_data['distance_from_board_scores'] 352 | error_scores = list(filter(lambda item: item is not None, error_scores)) 353 | intermediate_scores['distance_from_board'] = {} 
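Each metric below is converted to a percentile against the FineDiving reference distribution (distribution_data.pkl) by inserting the new error into the sorted reference sample and taking its rank; for metrics where a larger error is worse, the code uses one minus this value. A minimal sketch of the rule:
```
def percentile_of(err, distribution):
    temp = [e for e in distribution if e is not None]
    temp.append(err)
    temp.sort()
    return temp.index(err) / len(temp)

print(percentile_of(2.0, [1.0, 1.5, 2.5, 3.0]))  # 0.4 (rank 2 in a list of 5 after insertion)
```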
354 | intermediate_scores['distance_from_board']['raw_score'] = distance_from_board_score(dive_data)[1] 355 | intermediate_scores['distance_from_board']['frame_index'] = distance_from_board_score(dive_data)[2] 356 | err = distance_from_board_score(dive_data)[0] 357 | if err is not None: 358 | if err == 1: 359 | intermediate_scores['distance_from_board']['percentile'] = "safe, but too far from" 360 | intermediate_scores['distance_from_board']['score'] = 0.5 361 | 362 | elif err == 0: 363 | intermediate_scores['distance_from_board']['percentile'] = "a good distance from" 364 | intermediate_scores['distance_from_board']['score'] = 1 365 | else: 366 | intermediate_scores['distance_from_board']['percentile'] = "too close to" 367 | intermediate_scores['distance_from_board']['score'] = 0 368 | all_percentiles.append(intermediate_scores['distance_from_board']['score']) 369 | else: 370 | intermediate_scores['distance_from_board']['percentile'] = None 371 | intermediate_scores['distance_from_board']['score'] = None 372 | 373 | ### feet_apart_scores ### 374 | error_scores = distribution_data['feet_apart_scores'] 375 | error_scores = list(filter(lambda item: item is not None, error_scores)) 376 | intermediate_scores['feet_apart'] = {} 377 | intermediate_scores['feet_apart']['raw_score'] = feet_apart_score(dive_data)[0] 378 | intermediate_scores['feet_apart']['peaks'] = feet_apart_score(dive_data)[1] 379 | err = intermediate_scores['feet_apart']['raw_score'] 380 | if err is not None: 381 | temp = error_scores 382 | temp.append(err) 383 | temp.sort() 384 | intermediate_scores['feet_apart']['percentile'] = 1-temp.index(err)/len(temp) 385 | all_percentiles.append(1-temp.index(err)/len(temp)) 386 | else: 387 | intermediate_scores['feet_apart']['percentile'] = None 388 | 389 | ### knee_bend_scores ### 390 | error_scores = distribution_data['knee_bend_scores'] 391 | error_scores = list(filter(lambda item: item is not None, error_scores)) 392 | intermediate_scores['knee_bend'] = {} 393 | intermediate_scores['knee_bend']['raw_score'] = knee_bend_score(dive_data)[0] 394 | intermediate_scores['knee_bend']['frame_indices'] = knee_bend_score(dive_data)[1] 395 | err = intermediate_scores['knee_bend']['raw_score'] 396 | if err is not None: 397 | temp = error_scores 398 | temp.append(err) 399 | temp.sort() 400 | intermediate_scores['knee_bend']['percentile'] = 1-temp.index(err)/len(temp) 401 | all_percentiles.append(1-temp.index(err)/len(temp)) 402 | else: 403 | intermediate_scores['knee_bend']['percentile'] = None 404 | 405 | 406 | ### som_position_tightness_scores ### 407 | error_scores = distribution_data['som_position_tightness_scores'] 408 | error_scores = list(filter(lambda item: item is not None, error_scores)) 409 | intermediate_scores['som_position_tightness'] = {} 410 | position = find_position(dive_data) 411 | if position == 'tuck': 412 | intermediate_scores['som_position_tightness']['position'] = 'tuck' 413 | else: 414 | intermediate_scores['som_position_tightness']['position'] = 'pike' 415 | intermediate_scores['som_position_tightness']['raw_score'] = position_tightness_score(dive_data)[0] 416 | intermediate_scores['som_position_tightness']['frame_indices'] = position_tightness_score(dive_data)[2] 417 | err = intermediate_scores['som_position_tightness']['raw_score'] 418 | if err is not None: 419 | temp = error_scores 420 | temp.append(err) 421 | temp.sort() 422 | intermediate_scores['som_position_tightness']['percentile'] = 1-temp.index(err)/len(temp) 423 | 
all_percentiles.append(1-temp.index(err)/len(temp)) 424 | else: 425 | intermediate_scores['som_position_tightness']['percentile'] = None 426 | 427 | ### twist_position_tightness_scores ### 428 | error_scores = distribution_data['twist_position_tightness_scores'] 429 | error_scores = list(filter(lambda item: item is not None, error_scores)) 430 | intermediate_scores['twist_position_tightness'] = {} 431 | intermediate_scores['twist_position_tightness']['raw_score'] = position_tightness_score(dive_data)[1] 432 | intermediate_scores['twist_position_tightness']['frame_indices'] = position_tightness_score(dive_data)[3] 433 | err = intermediate_scores['twist_position_tightness']['raw_score'] 434 | if err is not None: 435 | temp = error_scores 436 | temp.append(err) 437 | temp.sort() 438 | intermediate_scores['twist_position_tightness']['percentile'] = 1-temp.index(err)/len(temp) 439 | all_percentiles.append(1-temp.index(err)/len(temp)) 440 | else: 441 | intermediate_scores['twist_position_tightness']['percentile'] = None 442 | 443 | ### over_under_rotation_scores ### 444 | error_scores = distribution_data['over_under_rotation_scores'] 445 | error_scores = list(filter(lambda item: item is not None, error_scores)) 446 | intermediate_scores['over_under_rotation'] = {} 447 | if over_under_rotation_score(dive_data) is not None: 448 | intermediate_scores['over_under_rotation']['raw_score'] = over_under_rotation_score(dive_data)[0] 449 | intermediate_scores['over_under_rotation']['frame_index'] = over_under_rotation_score(dive_data)[1] 450 | else: 451 | intermediate_scores['over_under_rotation']['raw_score'] = None 452 | intermediate_scores['over_under_rotation']['frame_index'] = None 453 | err = intermediate_scores['over_under_rotation']['raw_score'] 454 | if err is not None: 455 | temp = error_scores 456 | temp.append(err) 457 | temp.sort() 458 | intermediate_scores['over_under_rotation']['percentile'] = 1-temp.index(err)/len(temp) 459 | all_percentiles.append(1-temp.index(err)/len(temp)) 460 | 461 | else: 462 | intermediate_scores['over_under_rotation']['percentile'] = None 463 | 464 | ### splash_scores ### 465 | error_scores = distribution_data['splash_scores'] 466 | error_scores = list(filter(lambda item: item is not None, error_scores)) 467 | intermediate_scores['splash'] = {} 468 | intermediate_scores['splash']['raw_score'] = splash_score(dive_data)[0] 469 | intermediate_scores['splash']['maximum_index'] = splash_score(dive_data)[1] 470 | intermediate_scores['splash']['frame_indices'] = splash_score(dive_data)[2] 471 | 472 | err = intermediate_scores['splash']['raw_score'] 473 | if err is not None: 474 | temp = error_scores 475 | temp.append(err) 476 | temp.sort() 477 | intermediate_scores['splash']['percentile'] = 1-temp.index(err)/len(temp) 478 | all_percentiles.append(1-temp.index(err)/len(temp)) 479 | 480 | else: 481 | intermediate_scores['splash']['percentile'] = None 482 | 483 | ### straightness_during_entry_scores ### 484 | error_scores = distribution_data['straightness_during_entry_scores'] 485 | error_scores = list(filter(lambda item: item is not None, error_scores)) 486 | intermediate_scores['straightness_during_entry'] = {} 487 | if straightness_during_entry_score(dive_data) is not None: 488 | intermediate_scores['straightness_during_entry']['raw_score'] = straightness_during_entry_score(dive_data)[0] 489 | intermediate_scores['straightness_during_entry']['frame_indices'] = straightness_during_entry_score(dive_data)[1] 490 | else: 491 | 
intermediate_scores['straightness_during_entry']['raw_score'] = None 492 | intermediate_scores['straightness_during_entry']['frame_index'] = None 493 | 494 | err = intermediate_scores['straightness_during_entry']['raw_score'] 495 | if err is not None: 496 | temp = error_scores 497 | temp.append(err) 498 | temp.sort() 499 | intermediate_scores['straightness_during_entry']['percentile'] = 1-temp.index(err)/len(temp) 500 | all_percentiles.append(1-temp.index(err)/len(temp)) 501 | 502 | else: 503 | intermediate_scores['straightness_during_entry']['percentile'] = None 504 | 505 | ## overall score ### 506 | # Excellent: 10 507 | # Very Good: 8.5-9.5 508 | # Good: 7.0-8.0 509 | # Satisfactory: 5.0-6.5 510 | # Deficient: 2.5-4.5 511 | # Unsatisfactory: 0.5-2.0 512 | # Completely failed: 0 513 | overall_score = np.mean(all_percentiles) * 10 514 | intermediate_scores['overall_score'] = {} 515 | intermediate_scores['overall_score']['raw_score'] = overall_score 516 | if overall_score == 10: 517 | intermediate_scores['overall_score']['description'] = 'excellent' 518 | elif overall_score >=8.5 and overall_score <10: 519 | intermediate_scores['overall_score']['description'] = 'very good' 520 | elif overall_score >=7 and overall_score <8.5: 521 | intermediate_scores['overall_score']['description'] = 'good' 522 | elif overall_score >=5 and overall_score <7: 523 | intermediate_scores['overall_score']['description'] = 'satisfactory' 524 | elif overall_score >=2.5 and overall_score <5: 525 | intermediate_scores['overall_score']['description'] = 'deficient' 526 | elif overall_score >0 and overall_score <2.5: 527 | intermediate_scores['overall_score']['description'] = 'unsatisfactory' 528 | else: 529 | intermediate_scores['overall_score']['description'] = 'completely failed' 530 | 531 | return intermediate_scores 532 | -------------------------------------------------------------------------------- /score_report_generation/generate_report_functions.py: -------------------------------------------------------------------------------- 1 | """ 2 | generate_report_functions.py 3 | Author: Lauren Okamoto 4 | """ 5 | 6 | from jinja2 import Environment, FileSystemLoader 7 | import pickle 8 | import os 9 | import numpy as np 10 | from PIL import Image, ImageDraw 11 | from io import BytesIO 12 | import cv2 13 | import base64 14 | from pathlib import Path 15 | import torch 16 | import gradio as gr 17 | 18 | ############ Generate GIF functions ################################################################################### 19 | def generate_gif(local_directory, image_names, speed_factor=1, loop=0): 20 | """ 21 | Generate a GIF from a sequence of images paths saved in a local directory. 22 | 23 | Parameters: 24 | - local_directory (str): Directory path where the images are located 25 | - image_paths (list): List of filenames of the input images. 26 | - speed_factor (int): How fast the GIF is, the higher the less the delay is between frames 27 | - loop (int): Number of loops (0 for infinite loop). 
28 | 
29 |     Returns:
30 |     - Base64-encoded GIF content as a UTF-8 string
31 |     """
32 |     images = []
33 |     durations = []
34 |     for image_name in image_names:
35 |         img = Image.open(os.path.join(local_directory, image_name))
36 |         images.append(img)
37 |         try:
38 |             duration = img.info['duration']
39 |         except KeyError:
40 |             duration = 100 # Default duration in case 'duration' is not available
41 |         durations.append(duration)
42 | 
43 |     # Calculate the adjusted durations based on the speed factor
44 |     adjusted_durations = [int(duration / speed_factor) for duration in durations]
45 | 
46 |     # Save as GIF to an in-memory buffer
47 |     gif_buffer = BytesIO()
48 |     images[0].save(gif_buffer, format='GIF', save_all=True, append_images=images[1:], duration=adjusted_durations, loop=loop)
49 | 
50 |     # Base64-encode the buffer contents so the GIF can be embedded inline
51 |     gif_content = gif_buffer.getvalue()
52 |     gif_content = base64.b64encode(gif_content).decode('utf-8')
53 |     return gif_content
54 | 
55 | def generate_gif_from_frames(frames, speed_factor=1, loop=0, progress=gr.Progress()):
56 |     """
57 |     Generate a GIF from a sequence of cv2 frames.
58 | 
59 |     Parameters:
60 |     - frames (list): List of cv2 (BGR) frames
61 |     - speed_factor (float): Playback-speed multiplier; higher values shorten the delay between frames
62 |     - loop (int): Number of loops (0 for infinite loop).
63 | 
64 |     Returns:
65 |     - Base64-encoded GIF content as a UTF-8 string
66 |     """
67 |     durations = []
68 |     images = []
69 |     i = 0
70 |     for frame in frames:
71 |         image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
72 |         images.append(image)
73 |         duration = 100 # Default duration since raw frames carry no 'duration' metadata
74 |         durations.append(duration)
75 |         i+=1
76 | 
77 |     # Calculate the adjusted durations based on the speed factor
78 |     adjusted_durations = [int(duration / speed_factor) for duration in durations]
79 | 
80 |     # Save as GIF to an in-memory buffer
81 |     gif_buffer = BytesIO()
82 |     images[0].save(gif_buffer, format='GIF', save_all=True, append_images=images[1:], duration=adjusted_durations, loop=loop)
83 | 
84 |     # Base64-encode the buffer contents so the GIF can be embedded inline
85 |     gif_content = gif_buffer.getvalue()
86 |     gif_content = base64.b64encode(gif_content).decode('utf-8')
87 |     return gif_content
88 | 
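Both GIF helpers return base64 strings precisely so the score report can embed the animations inline; a usage sketch (the data-URI scheme is standard HTML, while the exact template markup is an assumption):
```
gif_b64 = generate_gif_from_frames(frames[10:40], speed_factor=0.5)
# embedded by the Jinja2 template as an inline data URI, e.g.:
#   <img src="data:image/gif;base64,{{ som_position_tightness_gif }}">
html_snippet = '<img src="data:image/gif;base64,{}">'.format(gif_b64)
```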
89 | ##########################################################################################################
90 | ############ Overlay Symbols on Frames ######################################################################
91 | 
92 | def draw_pose(keypoints, img, board_end_coord):
93 |     """draw the keypoints and the skeletons.
94 |     :params keypoints: the shape should be equal to [16,2] (MPII joint order)
95 |     :params img: BGR image to draw on, modified in place
96 |     """
97 |     NUM_KPTS = 16
98 |     assert keypoints.shape == (NUM_KPTS,2)
99 |     SKELETON = [[1,2],[1,0],[2,6],[3,6],[4,5],[3,4],[6,7],[7,8],[9,8],[7,12],[7,13],[11,12],[13,14],[14,15],[10,11]]
100 |     CocoColors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
101 |                   [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
102 |                   [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
103 |     for i in range(len(SKELETON)):
104 |         kpt_a, kpt_b = SKELETON[i][0], SKELETON[i][1]
105 |         x_a, y_a = keypoints[kpt_a][0], keypoints[kpt_a][1]
106 |         x_b, y_b = keypoints[kpt_b][0], keypoints[kpt_b][1]
107 |         cv2.circle(img, (int(x_a), int(y_a)), 6, CocoColors[i], -1)
108 |         cv2.circle(img, (int(x_b), int(y_b)), 6, CocoColors[i], -1)
109 |         cv2.line(img, (int(x_a), int(y_a)), (int(x_b), int(y_b)), CocoColors[i], 2)
110 | 
111 | def draw_symbols(opencv_image, pose_preds, board_end_coord, plat_outputs, splash_pred_mask, above_board=None):
112 |     if pose_preds is not None:
113 |         draw_pose(np.array(pose_preds)[0], opencv_image, board_end_coord)
114 |     if above_board is None or above_board==1:
115 |         draw_platform(opencv_image, plat_outputs)
116 |     draw_splash(opencv_image, splash_pred_mask)
117 |     return opencv_image
118 | 
119 | def draw_platform(opencv_image, output):
120 |     pred_classes = output['instances'].pred_classes.cpu().numpy()
121 |     platforms = np.where(pred_classes == 0)[0]
122 |     scores = output['instances'].scores[platforms]
123 |     if len(scores) == 0:
124 |         return
125 |     pred_masks = output['instances'].pred_masks[platforms]
126 |     max_instance = torch.argmax(scores)
127 |     pred_mask = np.array(pred_masks[max_instance].cpu())
128 |     # Convert the mask to a binary image
129 |     binary_mask = pred_mask.squeeze().astype(np.uint8)
130 |     contours, hierarchy = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
131 | 
132 |     cv2.drawContours(opencv_image, contours[0], -1, (36, 255, 12), thickness=5)
133 | 
134 | def draw_splash(opencv_image, pred_mask):
135 |     # Convert the mask to a binary image
136 |     if pred_mask is None:
137 |         return
138 |     binary_mask = pred_mask.squeeze().astype(np.uint8)
139 |     contours, hierarchy = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
140 | 
141 |     cv2.drawContours(opencv_image, contours[0], -1, (0, 0, 0), thickness=5)
142 | 
143 | ##########################################################################################################
144 | ############ Generate HTML template ######################################################################
145 | 
146 | def generate_report(template_path, data, local_directory, progress=gr.Progress()):
147 |     # Load the template environment
148 |     env = Environment(loader=FileSystemLoader('./score_report_generation/templates'))
149 | 
150 |     file_names = os.listdir(local_directory)
151 |     file_names.sort()
152 |     file_names = np.array(file_names)
153 |     progress(0.9, desc="Generating Score Report")
154 |     # Load the template
155 |     is_twister = data['twist_position_tightness']['raw_score'] is not None
156 |     overall_score_desc = data['overall_score']['description']
157 |     overall_score = '%.1f' % data['overall_score']['raw_score']
158 | 
159 |     feet_apart_score = round(float('%.2f' % data['feet_apart']['raw_score']))
160 |     feet_apart_peaks = file_names[data['feet_apart']['peaks']]
161 |     feet_apart_gif = None
162 |     has_feet_apart_peaks = False
163 |     if
##########################################################################################################
############ Generate HTML template ######################################################################

def generate_report(template_path, data, local_directory, progress=gr.Progress()):
    # Load the template environment
    env = Environment(loader=FileSystemLoader('./score_report_generation/templates'))

    file_names = os.listdir(local_directory)
    file_names.sort()
    file_names = np.array(file_names)
    progress(0.9, desc="Generating Score Report")

    # Collect per-criterion values for the template context
    is_twister = data['twist_position_tightness']['raw_score'] is not None
    overall_score_desc = data['overall_score']['description']
    overall_score = '%.1f' % data['overall_score']['raw_score']

    feet_apart_score = round(float('%.2f' % data['feet_apart']['raw_score']))
    feet_apart_peaks = file_names[data['feet_apart']['peaks']]
    feet_apart_gif = None
    has_feet_apart_peaks = False
    if len(feet_apart_peaks) > 0:
        has_feet_apart_peaks = True
        feet_apart_gif = generate_gif(local_directory, feet_apart_peaks, speed_factor=0.05)
    feet_apart_percentile = round(float('%.2f' % (data['feet_apart']['percentile'] * 100)))
    feet_apart_percentile_divided_by_ten = '%.1f' % (data['feet_apart']['percentile'] * 10)

    include_height_off_platform = False
    height_off_board_score = data['height_off_board']['raw_score']
    height_off_board_percentile = data['height_off_board']['percentile']
    height_off_board_description = None
    height_off_board_percentile_divided_by_ten = None
    encoded_height_off_board_frame = None
    if height_off_board_score is not None:
        include_height_off_platform = True
        height_off_board_score = round(float('%.2f' % data['height_off_board']['raw_score']))
        height_off_board_frame = file_names[data['height_off_board']['frame_index']]
        height_off_board_frame_path = os.path.join(local_directory, height_off_board_frame)
        with open(height_off_board_frame_path, "rb") as image_file:
            encoded_height_off_board_frame = base64.b64encode(image_file.read()).decode('utf-8')
        height_off_board_percentile = round(float('%.2f' % (data['height_off_board']['percentile'] * 100)))
        height_off_board_percentile_divided_by_ten = '%.1f' % (data['height_off_board']['percentile'] * 10)
        if float(height_off_board_percentile_divided_by_ten) > 5:
            height_off_board_description = "good"
        else:
            height_off_board_description = "a bit on the lower side"

    dist_from_board_score = '%.2f' % data['distance_from_board']['raw_score']
    dist_from_board_frame = file_names[data['distance_from_board']['frame_index']]
    dist_from_board_frame_path = os.path.join(local_directory, dist_from_board_frame)
    with open(dist_from_board_frame_path, "rb") as image_file:
        encoded_dist_from_board_frame = base64.b64encode(image_file.read()).decode('utf-8')
    dist_from_board_percentile = data['distance_from_board']['percentile']
    if 'good' in dist_from_board_percentile:
        dist_from_board_percentile_status = "Good"
    elif 'far' in dist_from_board_percentile:
        dist_from_board_percentile_status = "Too Far"
    else:
        dist_from_board_percentile_status = "Too Close"

    knee_bend_score = data['knee_bend']['raw_score']
    knee_bend_percentile = data['knee_bend']['percentile']
    knee_bend_frames = []
    knee_bend_percentile_divided_by_ten = None
    if knee_bend_score is not None:
        knee_bend_score = round(float('%.2f' % (data['knee_bend']['raw_score'])))
        knee_bend_frames = file_names[data['knee_bend']['frame_indices']]
        knee_bend_percentile = round(float('%.2f' % (knee_bend_percentile * 100)))
        knee_bend_percentile_divided_by_ten = '%.1f' % (data['knee_bend']['percentile'] * 10)

    som_position_tightness_score = data['som_position_tightness']['raw_score']
    som_position_tightness_percentile = data['som_position_tightness']['percentile']
    som_position_tightness_position = data['som_position_tightness']['position']
    som_position_tightness_frames = file_names[data['som_position_tightness']['frame_indices']]
    som_position_tightness_gif = None
    som_position_tightness_percentile_divided_by_ten = None
    if som_position_tightness_score is not None:
        if is_twister:
            som_position_tightness_score = round(float('%.2f' % (data['som_position_tightness']['raw_score'] + 15)))
        else:
            som_position_tightness_score = round(float('%.2f' % (data['som_position_tightness']['raw_score'])))
        som_position_tightness_gif = generate_gif(local_directory, som_position_tightness_frames)
        som_position_tightness_percentile = round(float('%.2f' % (som_position_tightness_percentile * 100)))
        som_position_tightness_percentile_divided_by_ten = '%.1f' % (data['som_position_tightness']['percentile'] * 10)

    twist_position_tightness_score = data['twist_position_tightness']['raw_score']
    twist_position_tightness_frames = []
    twist_position_tightness_gif = None
    twist_position_tightness_percentile = None
    twist_position_tightness_percentile_divided_by_ten = None
    if twist_position_tightness_score is not None:
        twist_position_tightness_score = round(float('%.2f' % twist_position_tightness_score))
        twist_position_tightness_frames = file_names[data['twist_position_tightness']['frame_indices']]
        twist_position_tightness_gif = generate_gif(local_directory, twist_position_tightness_frames)
        twist_position_tightness_percentile = round(float('%.2f' % (data['twist_position_tightness']['percentile'] * 100)))
        twist_position_tightness_percentile_divided_by_ten = '%.1f' % (data['twist_position_tightness']['percentile'] * 10)

    over_under_rotation_score = round(float('%.2f' % data['over_under_rotation']['raw_score']))
    over_under_rotation_frame = file_names[data['over_under_rotation']['frame_index']]
    over_under_rotation_percentile = round(float('%.2f' % (data['over_under_rotation']['percentile'] * 100)))
    over_under_rotation_percentile_divided_by_ten = '%.1f' % (data['over_under_rotation']['percentile'] * 10)

    straightness_during_entry_score = round(float('%.2f' % data['straightness_during_entry']['raw_score']))
    straightness_during_entry_frames = file_names[data['straightness_during_entry']['frame_indices']]
    straightness_during_entry_gif = generate_gif(local_directory, straightness_during_entry_frames, speed_factor=0.5)
    straightness_during_entry_percentile = round(float('%.2f' % (data['straightness_during_entry']['percentile'] * 100)))
    straightness_during_entry_percentile_divided_by_ten = '%.1f' % (data['straightness_during_entry']['percentile'] * 10)

    splash_score = round(float('%.2f' % data['splash']['raw_score']))
    splash_frame = file_names[data['splash']['maximum_index']]
    splash_indices = file_names[data['splash']['frame_indices']]
    splash_gif = None
    if len(splash_indices) > 0:
        splash_gif = generate_gif(local_directory, splash_indices)
    splash_percentile = round(float('%.2f' % (data['splash']['percentile'] * 100)))
    splash_percentile_divided_by_ten = '%.1f' % (data['splash']['percentile'] * 10)
    if float(splash_percentile) < 50:
        splash_description = 'on the larger side'
    else:
        splash_description = 'small'

    # Load the template and build its context
    template = env.get_template(template_path)
    data = {
        'local_directory': local_directory,
        'is_twister': is_twister,
        'overall_score_desc': overall_score_desc,
        'overall_score': overall_score,

        'feet_apart_score': feet_apart_score,
        'feet_apart_peaks': feet_apart_peaks,
        'has_feet_apart_peaks': has_feet_apart_peaks,
        'feet_apart_gif': feet_apart_gif,
        'feet_apart_percentile': feet_apart_percentile,
        'feet_apart_percentile_divided_by_ten': feet_apart_percentile_divided_by_ten,

        'include_height_off_platform': include_height_off_platform,
        'height_off_board_score': height_off_board_score,
        'height_off_board_percentile': height_off_board_percentile,
        'encoded_height_off_board_frame': encoded_height_off_board_frame,
        'height_off_board_percentile_divided_by_ten': height_off_board_percentile_divided_by_ten,
        'height_off_board_description': height_off_board_description,

        'dist_from_board_score': dist_from_board_score,
        'dist_from_board_frame': dist_from_board_frame,
        'encoded_dist_from_board_frame': encoded_dist_from_board_frame,
        'dist_from_board_percentile': dist_from_board_percentile,
        'dist_from_board_percentile_status': dist_from_board_percentile_status,

        'knee_bend_score': knee_bend_score,
        'knee_bend_frames': knee_bend_frames,
        'knee_bend_percentile': knee_bend_percentile,
        'knee_bend_percentile_divided_by_ten': knee_bend_percentile_divided_by_ten,

        'som_position_tightness_score': som_position_tightness_score,
        'som_position_tightness_frames': som_position_tightness_frames,
        'som_position_tightness_gif': som_position_tightness_gif,
        'som_position_tightness_position': som_position_tightness_position,
        'som_position_tightness_percentile': som_position_tightness_percentile,
        'som_position_tightness_percentile_divided_by_ten': som_position_tightness_percentile_divided_by_ten,

        'twist_position_tightness_score': twist_position_tightness_score,
        'twist_position_tightness_frames': twist_position_tightness_frames,
        'twist_position_tightness_percentile': twist_position_tightness_percentile,
        'twist_position_tightness_percentile_divided_by_ten': twist_position_tightness_percentile_divided_by_ten,
        'twist_position_tightness_gif': twist_position_tightness_gif,

        'over_under_rotation_score': over_under_rotation_score,
        'over_under_rotation_frame': over_under_rotation_frame,
        'over_under_rotation_percentile': over_under_rotation_percentile,
        'over_under_rotation_percentile_divided_by_ten': over_under_rotation_percentile_divided_by_ten,

        'straightness_during_entry_score': straightness_during_entry_score,
        'straightness_during_entry_gif': straightness_during_entry_gif,
        'straightness_during_entry_percentile': straightness_during_entry_percentile,
        'straightness_during_entry_percentile_divided_by_ten': straightness_during_entry_percentile_divided_by_ten,

        'splash_score': splash_score,
        'splash_frame': splash_frame,
        'splash_gif': splash_gif,
        'splash_percentile': splash_percentile,
        'splash_percentile_divided_by_ten': splash_percentile_divided_by_ten,
        'splash_description': splash_description,
    }
    # Render the template with the provided data
    report_content = template.render(data)
    return report_content
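# Usage sketch (assumption, not in the original file): `data` is the per-criterion
# score dictionary produced by the rule-based programs, and `local_directory`
# contains the extracted .jpg frames of the dive. `dive_scores` below is a
# hypothetical variable name; 'report_template_tables.html' is the template that
# ships in score_report_generation/templates.
#
#   report_html = generate_report('report_template_tables.html', dive_scores,
#                                 './output/frames/')
#   with open('./output/score_report.html', 'w') as f:
#       f.write(report_html)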
def generate_report_from_frames(template_path, data, frames):
    # Load the template environment
    env = Environment(loader=FileSystemLoader('./score_report_generation/templates'))

    frames = np.array(frames)

    # Collect per-criterion values for the template context
    is_twister = data['twist_position_tightness']['raw_score'] is not None
    overall_score_desc = data['overall_score']['description']
    overall_score = '%.1f' % data['overall_score']['raw_score']

    feet_apart_score = round(float('%.2f' % data['feet_apart']['raw_score']))
    feet_apart_peaks = frames[data['feet_apart']['peaks']]
    feet_apart_gif = None
    has_feet_apart_peaks = False
    if len(feet_apart_peaks) > 0:
        has_feet_apart_peaks = True
        feet_apart_gif = generate_gif_from_frames(feet_apart_peaks, speed_factor=0.05)
    feet_apart_percentile = round(float('%.2f' % (data['feet_apart']['percentile'] * 100)))
    feet_apart_percentile_divided_by_ten = '%.1f' % (data['feet_apart']['percentile'] * 10)

    include_height_off_platform = False
    height_off_board_score = data['height_off_board']['raw_score']
    height_off_board_frame = None
    height_off_board_percentile = None
    height_off_board_percentile_divided_by_ten = None
    height_off_board_description = None
    encoded_height_off_board_frame = None
    if height_off_board_score is not None:
        include_height_off_platform = True
        height_off_board_score = round(float('%.2f' % data['height_off_board']['raw_score']))
        # Encode the selected frame as a base64 JPEG so it can be embedded inline in the HTML
        height_off_board_frame = Image.fromarray(cv2.cvtColor(frames[data['height_off_board']['frame_index']], cv2.COLOR_BGR2RGB))
        height_buffer = BytesIO()
        height_off_board_frame.save(height_buffer, format='JPEG')
        encoded_height_off_board_frame = base64.b64encode(height_buffer.getvalue()).decode('utf-8')
        height_off_board_percentile = round(float('%.2f' % (data['height_off_board']['percentile'] * 100)))
        height_off_board_percentile_divided_by_ten = '%.1f' % (data['height_off_board']['percentile'] * 10)
        if float(height_off_board_percentile_divided_by_ten) > 5:
            height_off_board_description = "good"
        else:
            height_off_board_description = "a bit on the lower side"

    dist_from_board_score = '%.2f' % data['distance_from_board']['raw_score']
    dist_from_board_frame = Image.fromarray(cv2.cvtColor(frames[data['distance_from_board']['frame_index']], cv2.COLOR_BGR2RGB))
    dist_buffer = BytesIO()
    dist_from_board_frame.save(dist_buffer, format='JPEG')
    encoded_dist_from_board_frame = base64.b64encode(dist_buffer.getvalue()).decode('utf-8')
    dist_from_board_percentile = data['distance_from_board']['percentile']
    if 'good' in dist_from_board_percentile:
        dist_from_board_percentile_status = "Good"
    elif 'far' in dist_from_board_percentile:
        dist_from_board_percentile_status = "Too Far"
    else:
        dist_from_board_percentile_status = "Too Close"

    knee_bend_score = data['knee_bend']['raw_score']
    knee_bend_percentile = data['knee_bend']['percentile']
    knee_bend_frames = []
    knee_bend_percentile_divided_by_ten = None
    if knee_bend_score is not None:
        knee_bend_score = round(float('%.2f' % knee_bend_score))
        knee_bend_percentile = round(float('%.2f' % (knee_bend_percentile * 100)))
        knee_bend_frames = frames[data['knee_bend']['frame_indices']]
        knee_bend_percentile_divided_by_ten = '%.1f' % (data['knee_bend']['percentile'] * 10)

    som_position_tightness_score = data['som_position_tightness']['raw_score']
    som_position_tightness_percentile = data['som_position_tightness']['percentile']
    som_position_tightness_position = data['som_position_tightness']['position']
    som_position_tightness_frames = []
    som_position_tightness_gif = None
    som_position_tightness_percentile_divided_by_ten = None
    if som_position_tightness_score is not None:
        if is_twister:
            som_position_tightness_score = round(float('%.2f' % (data['som_position_tightness']['raw_score'] + 15)))
        else:
            som_position_tightness_score = round(float('%.2f' % (data['som_position_tightness']['raw_score'])))
        som_position_tightness_frames = frames[data['som_position_tightness']['frame_indices']]
        som_position_tightness_gif = generate_gif_from_frames(som_position_tightness_frames)
        som_position_tightness_percentile = round(float('%.2f' % (som_position_tightness_percentile * 100)))
        som_position_tightness_percentile_divided_by_ten = '%.1f' % (data['som_position_tightness']['percentile'] * 10)

    twist_position_tightness_score = data['twist_position_tightness']['raw_score']
    twist_position_tightness_frames = []
    twist_position_tightness_gif = None
    twist_position_tightness_percentile = None
    twist_position_tightness_percentile_divided_by_ten = None
    if twist_position_tightness_score is not None:
        twist_position_tightness_score = round(float('%.2f' % twist_position_tightness_score))
        twist_position_tightness_frames = frames[data['twist_position_tightness']['frame_indices']]
        twist_position_tightness_gif = generate_gif_from_frames(twist_position_tightness_frames)
        twist_position_tightness_percentile = round(float('%.2f' % (data['twist_position_tightness']['percentile'] * 100)))
        twist_position_tightness_percentile_divided_by_ten = '%.1f' % (data['twist_position_tightness']['percentile'] * 10)

    over_under_rotation_score = round(float('%.2f' % data['over_under_rotation']['raw_score']))
    over_under_rotation_frame = frames[data['over_under_rotation']['frame_index']]
    over_under_rotation_percentile = round(float('%.2f' % (data['over_under_rotation']['percentile'] * 100)))
    over_under_rotation_percentile_divided_by_ten = '%.1f' % (data['over_under_rotation']['percentile'] * 10)

    straightness_during_entry_score = round(float('%.2f' % data['straightness_during_entry']['raw_score']))
    straightness_during_entry_frames = frames[data['straightness_during_entry']['frame_indices']]
    straightness_during_entry_gif = generate_gif_from_frames(straightness_during_entry_frames, speed_factor=0.5)
    straightness_during_entry_percentile = round(float('%.2f' % (data['straightness_during_entry']['percentile'] * 100)))
    straightness_during_entry_percentile_divided_by_ten = '%.1f' % (data['straightness_during_entry']['percentile'] * 10)

    splash_score = round(float('%.2f' % data['splash']['raw_score']))
    splash_frame = frames[data['splash']['maximum_index']]
    splash_indices = frames[data['splash']['frame_indices']]
    splash_gif = None
    if len(splash_indices) > 0:
        splash_gif = generate_gif_from_frames(splash_indices)
    splash_percentile = round(float('%.2f' % (data['splash']['percentile'] * 100)))
    splash_percentile_divided_by_ten = '%.1f' % (data['splash']['percentile'] * 10)
    if float(splash_percentile) < 50:
        splash_description = 'on the larger side'
    else:
        splash_description = 'small'

    # Load the template and build its context
    template = env.get_template(template_path)
    data = {
        'is_twister': is_twister,
        'overall_score_desc': overall_score_desc,
        'overall_score': overall_score,

        'feet_apart_score': feet_apart_score,
        'feet_apart_peaks': feet_apart_peaks,
        'has_feet_apart_peaks': has_feet_apart_peaks,
        'feet_apart_gif': feet_apart_gif,
        'feet_apart_percentile': feet_apart_percentile,
        'feet_apart_percentile_divided_by_ten': feet_apart_percentile_divided_by_ten,

        'include_height_off_platform': include_height_off_platform,
        'height_off_board_score': height_off_board_score,
        'height_off_board_percentile': height_off_board_percentile,
        'encoded_height_off_board_frame': encoded_height_off_board_frame,
        'height_off_board_percentile_divided_by_ten': height_off_board_percentile_divided_by_ten,
        'height_off_board_description': height_off_board_description,

        'dist_from_board_score': dist_from_board_score,
        'dist_from_board_frame': dist_from_board_frame,
        'encoded_dist_from_board_frame': encoded_dist_from_board_frame,
        'dist_from_board_percentile': dist_from_board_percentile,
        'dist_from_board_percentile_status': dist_from_board_percentile_status,

        'knee_bend_score': knee_bend_score,
        'knee_bend_frames': knee_bend_frames,
        'knee_bend_percentile': knee_bend_percentile,
        'knee_bend_percentile_divided_by_ten': knee_bend_percentile_divided_by_ten,

        'som_position_tightness_score': som_position_tightness_score,
        'som_position_tightness_frames': som_position_tightness_frames,
        'som_position_tightness_gif': som_position_tightness_gif,
        'som_position_tightness_position': som_position_tightness_position,
        'som_position_tightness_percentile': som_position_tightness_percentile,
        'som_position_tightness_percentile_divided_by_ten': som_position_tightness_percentile_divided_by_ten,

        'twist_position_tightness_score': twist_position_tightness_score,
        'twist_position_tightness_frames': twist_position_tightness_frames,
        'twist_position_tightness_percentile': twist_position_tightness_percentile,
        'twist_position_tightness_percentile_divided_by_ten': twist_position_tightness_percentile_divided_by_ten,
        'twist_position_tightness_gif': twist_position_tightness_gif,

        'over_under_rotation_score': over_under_rotation_score,
        'over_under_rotation_frame': over_under_rotation_frame,
        'over_under_rotation_percentile': over_under_rotation_percentile,
        'over_under_rotation_percentile_divided_by_ten': over_under_rotation_percentile_divided_by_ten,

        'straightness_during_entry_score': straightness_during_entry_score,
        'straightness_during_entry_gif': straightness_during_entry_gif,
        'straightness_during_entry_percentile': straightness_during_entry_percentile,
        'straightness_during_entry_percentile_divided_by_ten': straightness_during_entry_percentile_divided_by_ten,

        'splash_score': splash_score,
        'splash_frame': splash_frame,
        'splash_gif': splash_gif,
        'splash_percentile': splash_percentile,
        'splash_percentile_divided_by_ten': splash_percentile_divided_by_ten,
        'splash_description': splash_description,
    }
    # Render the template with the provided data
    report_content = template.render(data)
    return report_content
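# Usage sketch (assumption): `generate_report_from_frames` takes in-memory BGR
# frames instead of a frame directory, e.g. frames decoded straight from a clip
# with OpenCV. `dive_scores` is a hypothetical variable name.
#
#   cap = cv2.VideoCapture('path/to/video.mp4')
#   frames = []
#   while True:
#       ok, frame = cap.read()
#       if not ok:
#           break
#       frames.append(frame)
#   cap.release()
#   report_html = generate_report_from_frames('report_template_tables.html',
#                                             dive_scores, frames)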
def generate_symbols_report(template_path, dive_data, frames):
    # Load the template environment
    env = Environment(loader=FileSystemLoader('./score_report_generation/templates'))
    template = env.get_template(template_path)
    pose_frames = []
    for i in range(len(dive_data['pose_pred'])):
        pose_frame = draw_symbols(frames[i], dive_data['pose_pred'][i], dive_data['board_end_coords'][i], dive_data['plat_outputs'][i], dive_data['splash_pred_masks'][i])
        pose_frames.append(pose_frame)
    pose_gif = generate_gif_from_frames(pose_frames, speed_factor=2)
    pose_data = {}
    pose_data['pose_gif'] = pose_gif
    html = template.render(pose_data)
    return html
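# Note (added for clarity): `generate_symbols_report` above consumes in-memory
# frames, while `generate_symbols_report_precomputed` below reads pre-extracted
# .jpg frames from `local_directory` and reports progress through a Gradio
# progress bar. Illustrative call, with `template_name` as a hypothetical
# symbols-template filename:
#
#   html = generate_symbols_report_precomputed(template_name, dive_data,
#                                              './output/frames/')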
def generate_symbols_report_precomputed(template_path, dive_data, local_directory, progress=gr.Progress()):
    file_names = os.listdir(local_directory)
    file_names.sort()
    file_names = np.array(file_names)

    if 'above_boards' in dive_data:
        above_boards = dive_data['above_boards']
    else:
        above_boards = [None] * len(file_names)

    # Load the template environment
    env = Environment(loader=FileSystemLoader('./score_report_generation/templates'))
    template = env.get_template(template_path)
    pose_frames = []
    counter = 0
    for i in range(len(file_names)):
        progress(i / (len(file_names) + 10), desc="Abstracting Symbols")
        if file_names[i][-4:] != ".jpg":
            continue
        opencv_image = cv2.imread(os.path.join(local_directory, file_names[i]))
        pose_frame = draw_symbols(opencv_image, dive_data['pose_pred'][counter], dive_data['board_end_coords'][counter], dive_data['plat_outputs'][counter], dive_data['splash_pred_masks'][counter], above_board=above_boards[counter])
        pose_frames.append(pose_frame)
        counter += 1
    pose_gif = generate_gif_from_frames(pose_frames, speed_factor=2, progress=progress)
    pose_data = {}
    pose_data['pose_gif'] = pose_gif
    html = template.render(pose_data)
    return html
--------------------------------------------------------------------------------