├── configs ├── soccernet_delta │ ├── spotting_labels.csv │ ├── spotting_class_weights.csv │ ├── evaluation_tolerances.json │ └── nms_windows.csv ├── soccernet_cards_confidence │ ├── nms_windows.csv │ ├── spotting_labels.csv │ ├── evaluation_tolerances.json │ └── spotting_class_weights.csv ├── soccernet_challenge_confidence │ ├── nms_windows.csv │ ├── spotting_labels.csv │ ├── spotting_class_weights.csv │ └── evaluation_tolerances.json ├── soccernet_delta_soft_nms │ ├── spotting_labels.csv │ ├── evaluation_tolerances.json │ ├── spotting_class_weights.csv │ └── nms_windows.csv ├── soccernet_challenge_delta │ ├── spotting_labels.csv │ ├── evaluation_tolerances.json │ ├── spotting_class_weights.csv │ └── nms_windows.csv ├── soccernet_v1_confidence │ ├── evaluation_tolerances.json │ ├── nms_windows.csv │ └── spotting_labels.csv ├── soccernet_challenge_delta_soft_nms │ ├── spotting_labels.csv │ ├── evaluation_tolerances.json │ ├── spotting_class_weights.csv │ └── nms_windows.csv └── soccernet_confidence │ ├── evaluation_tolerances.json │ ├── spotting_class_weights.csv │ ├── nms_windows.csv │ ├── spotting_labels.csv │ └── segmentation_labels.csv ├── data └── splits │ ├── BallValid.json │ ├── BallTest.json │ ├── BallChallenge.json │ ├── BallTrain.json │ ├── LICENSING_NOTICES.md │ ├── SoccerNetCameraChangesValid.json │ ├── SoccerNetCameraChangesTest.json │ ├── SoccerNetGamesChallenge.json │ ├── SoccerNetCameraChangesChallenge.json │ ├── SoccerNetGamesTest.json │ └── SoccerNetGamesValid.json ├── bin ├── test.py ├── profile_validation.py ├── project_features.py ├── command_user_constants.py ├── extract_features.py ├── make_pca_features.py ├── train.py ├── create_features_from_results.py ├── create_normalizer.py ├── create_multi_task_games_csv.py ├── make_pca_transform.py ├── predict_on_videos.py ├── create_averaging_predictor.py ├── transform_features.py └── evaluate_spotting_jsons.py ├── spivak ├── data │ ├── output_names.py │ ├── video_io.py │ ├── label_map.py │ ├── 
soccernet_constants.py │ ├── dataset.py │ ├── video_chunk_iterator.py │ └── dataset_splits.py ├── evaluation │ ├── task_evaluation.py │ ├── aggregate.py │ └── segmentation_evaluation.py ├── models │ ├── predictor.py │ ├── projector.py │ ├── assembly │ │ ├── huggingface_activations.py │ │ ├── bottom_up.py │ │ └── weight_creator.py │ ├── delta_dense_predictor.py │ ├── trainer.py │ ├── sam_model.py │ ├── averaging_predictor.py │ └── non_maximum_suppression.py ├── application │ ├── worker_manager.py │ ├── feature_utils.py │ └── command_utils.py ├── html_visualization │ ├── utils.py │ └── segmentation_visualization.py ├── feature_extraction │ ├── soccernet_v2.py │ └── extraction.py └── video_visualization │ └── recognition_visualization.py ├── PULL_REQUEST_TEMPLATE.md ├── Contributing.md ├── setup.py ├── .gitignore └── Code_of_Conduct.md /configs/soccernet_delta/spotting_labels.csv: -------------------------------------------------------------------------------- 1 | ../soccernet_confidence/spotting_labels.csv -------------------------------------------------------------------------------- /configs/soccernet_cards_confidence/nms_windows.csv: -------------------------------------------------------------------------------- 1 | ../soccernet_confidence/nms_windows.csv -------------------------------------------------------------------------------- /configs/soccernet_challenge_confidence/nms_windows.csv: -------------------------------------------------------------------------------- 1 | ../soccernet_confidence/nms_windows.csv -------------------------------------------------------------------------------- /configs/soccernet_delta_soft_nms/spotting_labels.csv: -------------------------------------------------------------------------------- 1 | ../soccernet_delta/spotting_labels.csv -------------------------------------------------------------------------------- /configs/soccernet_cards_confidence/spotting_labels.csv: 
-------------------------------------------------------------------------------- 1 | ../soccernet_confidence/spotting_labels.csv -------------------------------------------------------------------------------- /configs/soccernet_challenge_confidence/spotting_labels.csv: -------------------------------------------------------------------------------- 1 | ../soccernet_confidence/spotting_labels.csv -------------------------------------------------------------------------------- /configs/soccernet_delta/spotting_class_weights.csv: -------------------------------------------------------------------------------- 1 | ../soccernet_confidence/spotting_class_weights.csv -------------------------------------------------------------------------------- /configs/soccernet_challenge_delta/spotting_labels.csv: -------------------------------------------------------------------------------- 1 | ../soccernet_challenge_confidence/spotting_labels.csv -------------------------------------------------------------------------------- /configs/soccernet_delta_soft_nms/evaluation_tolerances.json: -------------------------------------------------------------------------------- 1 | ../soccernet_delta/evaluation_tolerances.json -------------------------------------------------------------------------------- /configs/soccernet_delta_soft_nms/spotting_class_weights.csv: -------------------------------------------------------------------------------- 1 | ../soccernet_delta/spotting_class_weights.csv -------------------------------------------------------------------------------- /configs/soccernet_v1_confidence/evaluation_tolerances.json: -------------------------------------------------------------------------------- 1 | ../soccernet_confidence/evaluation_tolerances.json -------------------------------------------------------------------------------- /configs/soccernet_cards_confidence/evaluation_tolerances.json: -------------------------------------------------------------------------------- 1 
| ../soccernet_confidence/evaluation_tolerances.json -------------------------------------------------------------------------------- /configs/soccernet_challenge_delta_soft_nms/spotting_labels.csv: -------------------------------------------------------------------------------- 1 | ../soccernet_challenge_delta/spotting_labels.csv -------------------------------------------------------------------------------- /configs/soccernet_challenge_confidence/spotting_class_weights.csv: -------------------------------------------------------------------------------- 1 | ../soccernet_confidence/spotting_class_weights.csv -------------------------------------------------------------------------------- /configs/soccernet_v1_confidence/nms_windows.csv: -------------------------------------------------------------------------------- 1 | label,window 2 | Goal,30.0 3 | Card,30.0 4 | Substitution,30.0 5 | -------------------------------------------------------------------------------- /configs/soccernet_v1_confidence/spotting_labels.csv: -------------------------------------------------------------------------------- 1 | id,name,order 2 | 0,Goal,0 3 | 1,Card,1 4 | 2,Substitution,2 5 | -------------------------------------------------------------------------------- /configs/soccernet_challenge_delta/evaluation_tolerances.json: -------------------------------------------------------------------------------- 1 | ../soccernet_challenge_confidence/evaluation_tolerances.json -------------------------------------------------------------------------------- /configs/soccernet_challenge_delta/spotting_class_weights.csv: -------------------------------------------------------------------------------- 1 | ../soccernet_challenge_confidence/spotting_class_weights.csv -------------------------------------------------------------------------------- /configs/soccernet_challenge_delta_soft_nms/evaluation_tolerances.json: -------------------------------------------------------------------------------- 
1 | ../soccernet_challenge_delta/evaluation_tolerances.json -------------------------------------------------------------------------------- /configs/soccernet_challenge_delta_soft_nms/spotting_class_weights.csv: -------------------------------------------------------------------------------- 1 | ../soccernet_challenge_delta/spotting_class_weights.csv -------------------------------------------------------------------------------- /data/splits/BallValid.json: -------------------------------------------------------------------------------- 1 | { 2 | "england_efl": { 3 | "2019-2020": [ 4 | "2019-10-01 - Middlesbrough - Preston North End" 5 | ] 6 | } 7 | } -------------------------------------------------------------------------------- /data/splits/BallTest.json: -------------------------------------------------------------------------------- 1 | { 2 | "england_efl": { 3 | "2019-2020": [ 4 | "2019-10-01 - Reading - Fulham", 5 | "2019-10-01 - Stoke City - Huddersfield Town" 6 | ] 7 | } 8 | } -------------------------------------------------------------------------------- /data/splits/BallChallenge.json: -------------------------------------------------------------------------------- 1 | { 2 | "england_efl": { 3 | "2019-2020": [ 4 | "2019-10-01 - Wigan Athletic - Birmingham City", 5 | "2019-10-02 - Cardiff City - Queens Park Rangers" 6 | ] 7 | } 8 | } -------------------------------------------------------------------------------- /configs/soccernet_confidence/evaluation_tolerances.json: -------------------------------------------------------------------------------- 1 | { 2 | "sets": { 3 | "loose": [5.0, 10.0, 15.0, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0], 4 | "tight": [1.0, 2.0, 3.0, 4.0, 5.0] 5 | }, 6 | "main": "loose", 7 | "extra": [0.5, 6.0, 8.0] 8 | } 9 | -------------------------------------------------------------------------------- /configs/soccernet_challenge_confidence/evaluation_tolerances.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "sets": { 3 | "loose": [5.0, 10.0, 15.0, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0], 4 | "tight": [1.0, 2.0, 3.0, 4.0, 5.0] 5 | }, 6 | "main": "tight", 7 | "extra": [0.5, 6.0, 8.0] 8 | } 9 | -------------------------------------------------------------------------------- /configs/soccernet_delta/evaluation_tolerances.json: -------------------------------------------------------------------------------- 1 | { 2 | "sets": { 3 | "loose": [5.0, 10.0, 15.0, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0], 4 | "medium": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0], 5 | "tight": [1.0, 2.0, 3.0, 4.0, 5.0] 6 | }, 7 | "main": "medium", 8 | "extra": [0.5] 9 | } 10 | -------------------------------------------------------------------------------- /data/splits/BallTrain.json: -------------------------------------------------------------------------------- 1 | { 2 | "england_efl": { 3 | "2019-2020": [ 4 | "2019-10-01 - Blackburn Rovers - Nottingham Forest", 5 | "2019-10-01 - Brentford - Bristol City", 6 | "2019-10-01 - Hull City - Sheffield Wednesday", 7 | "2019-10-01 - Leeds United - West Bromwich" 8 | ] 9 | } 10 | } -------------------------------------------------------------------------------- /configs/soccernet_delta/nms_windows.csv: -------------------------------------------------------------------------------- 1 | label,window 2 | Penalty,3.0 3 | Kick-off,3.0 4 | Goal,3.0 5 | Substitution,3.0 6 | Offside,3.0 7 | Shots on target,3.0 8 | Shots off target,3.0 9 | Clearance,3.0 10 | Ball out of play,3.0 11 | Throw-in,3.0 12 | Foul,3.0 13 | Indirect free-kick,3.0 14 | Direct free-kick,3.0 15 | Corner,3.0 16 | Yellow card,3.0 17 | Red card,3.0 18 | Yellow->red card,3.0 19 | -------------------------------------------------------------------------------- /configs/soccernet_delta_soft_nms/nms_windows.csv: 
-------------------------------------------------------------------------------- 1 | label,window 2 | Penalty,4.0 3 | Kick-off,4.0 4 | Goal,4.0 5 | Substitution,4.0 6 | Offside,4.0 7 | Shots on target,4.0 8 | Shots off target,4.0 9 | Clearance,4.0 10 | Ball out of play,4.0 11 | Throw-in,4.0 12 | Foul,4.0 13 | Indirect free-kick,4.0 14 | Direct free-kick,4.0 15 | Corner,4.0 16 | Yellow card,4.0 17 | Red card,4.0 18 | Yellow->red card,4.0 19 | -------------------------------------------------------------------------------- /configs/soccernet_challenge_delta/nms_windows.csv: -------------------------------------------------------------------------------- 1 | label,window 2 | Penalty,3.0 3 | Kick-off,3.0 4 | Goal,3.0 5 | Substitution,3.0 6 | Offside,3.0 7 | Shots on target,3.0 8 | Shots off target,3.0 9 | Clearance,3.0 10 | Ball out of play,3.0 11 | Throw-in,3.0 12 | Foul,3.0 13 | Indirect free-kick,3.0 14 | Direct free-kick,3.0 15 | Corner,3.0 16 | Yellow card,3.0 17 | Red card,3.0 18 | Yellow->red card,3.0 19 | -------------------------------------------------------------------------------- /configs/soccernet_challenge_delta_soft_nms/nms_windows.csv: -------------------------------------------------------------------------------- 1 | label,window 2 | Penalty,4.0 3 | Kick-off,4.0 4 | Goal,4.0 5 | Substitution,4.0 6 | Offside,4.0 7 | Shots on target,4.0 8 | Shots off target,4.0 9 | Clearance,4.0 10 | Ball out of play,4.0 11 | Throw-in,4.0 12 | Foul,4.0 13 | Indirect free-kick,4.0 14 | Direct free-kick,4.0 15 | Corner,4.0 16 | Yellow card,4.0 17 | Red card,4.0 18 | Yellow->red card,4.0 19 | -------------------------------------------------------------------------------- /configs/soccernet_confidence/spotting_class_weights.csv: -------------------------------------------------------------------------------- 1 | label,weight 2 | Penalty,1.0 3 | Kick-off,1.0 4 | Goal,1.0 5 | Substitution,1.0 6 | Offside,1.0 7 | Shots on target,1.0 8 | Shots off target,1.0 9 | Clearance,1.0 
10 | Ball out of play,1.0 11 | Throw-in,1.0 12 | Foul,1.0 13 | Indirect free-kick,1.0 14 | Direct free-kick,1.0 15 | Corner,1.0 16 | Yellow card,1.0 17 | Red card,1.0 18 | Yellow->red card,1.0 19 | -------------------------------------------------------------------------------- /configs/soccernet_cards_confidence/spotting_class_weights.csv: -------------------------------------------------------------------------------- 1 | label,weight 2 | Penalty,0.0 3 | Kick-off,0.0 4 | Goal,0.0 5 | Substitution,0.0 6 | Offside,0.0 7 | Shots on target,0.0 8 | Shots off target,0.0 9 | Clearance,0.0 10 | Ball out of play,0.0 11 | Throw-in,0.0 12 | Foul,0.0 13 | Indirect free-kick,0.0 14 | Direct free-kick,0.0 15 | Corner,0.0 16 | Yellow card,1.0 17 | Red card,1.0 18 | Yellow->red card,1.0 19 | -------------------------------------------------------------------------------- /configs/soccernet_confidence/nms_windows.csv: -------------------------------------------------------------------------------- 1 | label,window 2 | Penalty,20.0 3 | Kick-off,20.0 4 | Goal,20.0 5 | Substitution,20.0 6 | Offside,20.0 7 | Shots on target,20.0 8 | Shots off target,20.0 9 | Clearance,20.0 10 | Ball out of play,20.0 11 | Throw-in,20.0 12 | Foul,20.0 13 | Indirect free-kick,20.0 14 | Direct free-kick,20.0 15 | Corner,20.0 16 | Yellow card,20.0 17 | Red card,20.0 18 | Yellow->red card,20.0 19 | -------------------------------------------------------------------------------- /configs/soccernet_confidence/spotting_labels.csv: -------------------------------------------------------------------------------- 1 | id,name,order 2 | 0,Penalty,14 3 | 1,Kick-off,9 4 | 2,Goal,13 5 | 3,Substitution,8 6 | 4,Offside,11 7 | 5,Shots on target,5 8 | 6,Shots off target,6 9 | 7,Clearance,4 10 | 8,Ball out of play,0 11 | 9,Throw-in,1 12 | 10,Foul,2 13 | 11,Indirect free-kick,3 14 | 12,Direct free-kick,10 15 | 13,Corner,7 16 | 14,Yellow card,12 17 | 15,Red card,15 18 | 16,Yellow->red card,16 19 | 
if __name__ == "__main__":
    # Script entry point: switch the root logger to DEBUG so that everything
    # is reported, then hand the parsed command-line arguments to the test
    # routine.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.DEBUG)
    parsed_args = get_args()
    test(parsed_args)
# File suffixes (lower-case, including the leading dot) treated as videos.
VIDEO_EXTENSIONS = {".mp4", ".avi", ".mkv"}


def list_video_paths(input_videos_dir: Path) -> List[Path]:
    """Return the sorted video paths found directly inside a directory.

    Only the immediate entries of input_videos_dir are inspected (there is
    no recursion); an entry is kept when is_video_path() accepts it.

    Args:
        input_videos_dir: Directory whose entries should be listed.

    Returns:
        The matching paths, in sorted order.
    """
    return sorted(
        file_path
        for file_path in input_videos_dir.iterdir()
        if is_video_path(file_path))


def is_video_path(video_path: Path) -> bool:
    """Tell whether a path has one of the known video file suffixes.

    The comparison ignores case, so that files such as "GAME.MP4" or
    "clip.Mkv" are recognized just like their lower-case counterparts.

    Args:
        video_path: Path to check. Only its suffix is inspected; the file
            system is not accessed.

    Returns:
        True if the suffix of video_path is in VIDEO_EXTENSIONS.
    """
    # VIDEO_EXTENSIONS only holds lower-case suffixes, so normalize the
    # suffix before the membership test.
    return video_path.suffix.lower() in VIDEO_EXTENSIONS
class PredictorInterface(metaclass=ABCMeta):
    """Abstract interface that every video predictor has to implement.

    All methods are abstract: a concrete subclass must provide prediction,
    prediction with saving, weight loading and model saving.
    """

    @abstractmethod
    def predict_video(self, video_datum: VideoDatum) -> VideoOutputs:
        """Compute the outputs for one video.

        Args:
            video_datum: The video to run the prediction on.

        Returns:
            A dictionary from output name (a string) to a numpy array.
        """

    @abstractmethod
    def predict_video_and_save(
            self, video_datum: VideoDatum, nms: FlexibleNonMaximumSuppression,
            base_path: Path) -> None:
        """Compute the outputs for one video and save them.

        NOTE(review): presumably nms is applied to the detection scores and
        the results are written under base_path; confirm against the
        concrete implementations.

        Args:
            video_datum: The video to run the prediction on.
            nms: Non-maximum suppression object made available to the
                implementation.
            base_path: Path given to the implementation for saving.
        """

    @abstractmethod
    def load_weights(self, weights_path: str) -> None:
        """Load model weights from weights_path."""

    @abstractmethod
    def save_model(self, model_path: str) -> None:
        """Save the model to model_path."""
4 | 5 | The following list of JSON files are `Copyright (c) 2023 SoccerNet`, under the 6 | MIT license. (You may obtain a copy of the MIT License at 7 | <https://opensource.org/licenses/MIT>.) These files were copied from the 8 | SoccerNet pip package, version 0.1.48, which can be found at 9 | <https://pypi.org/project/SoccerNet/0.1.48/>. 10 | - BallChallenge.json 11 | - BallTest.json 12 | - BallTrain.json 13 | - BallValid.json 14 | - SoccerNetCameraChangesChallenge.json 15 | - SoccerNetCameraChangesTrain.json 16 | - SoccerNetGamesChallenge.json 17 | - SoccerNetGamesTrain.json 18 | - SoccerNetCameraChangesTest.json 19 | - SoccerNetCameraChangesValid.json 20 | - SoccerNetGamesTest.json 21 | - SoccerNetGamesValid.json 22 | 23 | The following list of CSV files are `Copyright (c) 2023, Yahoo Inc.`, licensed 24 | under the Apache License, Version 2.0. See the accompanying LICENSE file for 25 | terms. 26 | - SoccerNetSpottingAndCameraChanges.csv 27 | - SoccerNetSpottingAndCameraChangesLarge.csv 28 | -------------------------------------------------------------------------------- /bin/profile_validation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2023, Yahoo Inc. 4 | # Licensed under the Apache License, Version 2.0. 5 | # See the accompanying LICENSE file for terms. 
# When True, TensorFlow is asked to grow GPU memory on demand before the
# validation run starts.
ALLOW_MEMORY_GROWTH = False
# File that receives the raw cProfile statistics.
TIMING_FILENAME = "validation.prof"


def main():
    """Profile a single validation run and report the slowest calls.

    The raw profile is dumped to TIMING_FILENAME, the 30 entries with the
    largest cumulative time are printed, and finally a pdb session is opened
    so that the collected statistics can be explored interactively.
    """
    profiler = cProfile.Profile()
    profiler.enable()
    _run_validation()
    profiler.disable()
    profiler.dump_stats(TIMING_FILENAME)
    stats = pstats.Stats(TIMING_FILENAME)
    stats.sort_stats('cumulative')
    stats.print_stats(30)
    # Deliberately drop into the debugger once the report has been printed.
    import pdb
    pdb.set_trace()


def _run_validation():
    """Create the validation evaluation once and print it."""
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    # disable_eager_execution is used here to match validation.py.
    tf.compat.v1.disable_eager_execution()
    if ALLOW_MEMORY_GROWTH:
        _allow_memory_growth()
    args = get_args()
    evaluation = create_validation_evaluation(args, best_metric=0.0)
    print(f"Evaluation: {evaluation}")


def _allow_memory_growth():
    """Enable memory growth on every physical GPU TensorFlow reports."""
    for gpu in tf.config.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu, True)


if __name__ == "__main__":
    main()
# Feature name passed to create_video_feature_infos() for the outputs.
PROJECTED_FEATURE_NAME = "projected"


def main() -> None:
    """Project the features of every video found and write the results.

    The input feature files are located from the command-line arguments, the
    output directories are created, the projector is loaded, and then one
    projected feature file is written per input video.
    """
    logging.getLogger().setLevel(logging.INFO)
    args = get_args()
    input_features_dir = Path(args.features_dir)
    output_dir = Path(args.results_dir)
    infos = create_video_feature_infos(
        [input_features_dir], [args.feature_name], output_dir,
        PROJECTED_FEATURE_NAME)
    print(f"Found {len(infos)} video feature files")
    output_dir.mkdir(parents=True, exist_ok=True)
    make_output_directories(infos)
    projector = load_projector(args, create_label_maps(args))
    for info in infos:
        _create_projected_features_file(info, projector)


def _create_projected_features_file(
        video_feature_info: VideoFeatureInfo, projector: Projector) -> None:
    """Load one feature file, project it and save the projected features.

    Args:
        video_feature_info: Supplies the input paths (only the first one is
            read) and the output path for the video.
        projector: Projector applied to the loaded features.
    """
    source_path = video_feature_info.input_paths[0]
    projected_features = projector.project(np.load(str(source_path)))
    print(f"Writing projected features to {video_feature_info.output_path}")
    np.save(str(video_feature_info.output_path), projected_features)


if __name__ == "__main__":
    main()
3 | 4 | ### Questions 5 | 6 | If you have a question that needs an answer, [create an issue](https://help.github.com/articles/creating-an-issue/), and label it as a question. 7 | 8 | ### Issues for bugs or feature requests 9 | 10 | If you encounter any bugs in the code, or want to request a new feature or enhancement, please [create an issue](https://help.github.com/articles/creating-an-issue/) to report it. Kindly add a label to indicate what type of issue it is. 11 | 12 | ### Contribute Code 13 | We welcome your pull requests for bug fixes. To implement something new, please create an issue first so we can discuss it together. 14 | 15 | #### Creating a Pull Request 16 | Please follow [best practices](https://github.com/trein/dev-best-practices/wiki/Git-Commit-Best-Practices) for creating git commits. 17 | 18 | When your code is ready to be submitted, [submit a pull request](https://help.github.com/articles/creating-a-pull-request/) to begin the code review process. 19 | 20 | We only seek to accept code that you are authorized to contribute to the project. We have added a [pull request template](PULL_REQUEST_TEMPLATE.md) on our projects so that your contributions are made with the following confirmation: 21 | 22 | > I confirm that this contribution is made under the terms of the license found in the root directory of this repository's source tree and that I have the authority necessary to make this contribution on behalf of its copyright owner. 23 | 24 | ## Code of Conduct 25 | 26 | We encourage inclusive and professional interactions on our project. We welcome everyone to open an issue, improve the documentation, report a bug, or submit a pull request. By participating in this project, you agree to abide by the [Yahoo Code of Conduct](Code_of_Conduct.md). If you feel there is a conduct issue related to this project, please raise it per the Code of Conduct process and we will address it. 
27 | -------------------------------------------------------------------------------- /spivak/models/projector.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023, Yahoo Inc. 2 | # Licensed under the Apache License, Version 2.0. 3 | # See the accompanying LICENSE file for terms. 4 | 5 | from typing import List, Tuple 6 | 7 | import numpy as np 8 | from tensorflow.python.keras import Model 9 | 10 | from spivak.data.video_chunk_iterator import VideoChunkIteratorProvider 11 | 12 | 13 | class Projector: 14 | 15 | def __init__( 16 | self, projector_model: Model, 17 | video_chunk_iterator_provider: VideoChunkIteratorProvider) -> None: 18 | self._projector_model = projector_model 19 | self._video_chunk_iterator_provider = video_chunk_iterator_provider 20 | 21 | def project(self, video_features: np.ndarray) -> np.ndarray: 22 | input_chunk_batch, valid_chunk_sizes = self._prepare_input_batch( 23 | video_features) 24 | chunk_output_batch = self._projector_model.predict(input_chunk_batch) 25 | valid_chunk_outputs = [ 26 | chunk_output_batch[c][0:valid_chunk_size] 27 | for c, valid_chunk_size in enumerate(valid_chunk_sizes)] 28 | return self._accumulate_outputs(valid_chunk_outputs, video_features) 29 | 30 | def _prepare_input_batch( 31 | self, features: np.ndarray) -> Tuple[np.ndarray, List[int]]: 32 | input_chunk_iterator = self._video_chunk_iterator_provider.provide( 33 | features) 34 | return input_chunk_iterator.prepare_input_batch() 35 | 36 | def _accumulate_outputs( 37 | self, valid_chunk_outputs: List[np.ndarray], 38 | features: np.ndarray) -> np.ndarray: 39 | chunk_output_iterator = self._video_chunk_iterator_provider.provide( 40 | features) 41 | num_frames = features.shape[0] 42 | num_projected_features = valid_chunk_outputs[0].shape[2] 43 | projected_features = np.zeros((num_frames, 1, num_projected_features)) 44 | chunk_output_iterator.accumulate_chunk_outputs( 45 | projected_features, valid_chunk_outputs) 46 | 
return np.squeeze(projected_features, axis=1) 47 | -------------------------------------------------------------------------------- /spivak/application/worker_manager.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023, Yahoo Inc. 2 | # Licensed under the Apache License, Version 2.0. 3 | # See the accompanying LICENSE file for terms. 4 | 5 | import logging 6 | import warnings 7 | from multiprocessing import Queue, Process 8 | from typing import Callable, Optional, Tuple 9 | 10 | MODULE_ADDONS_INSTALL = "tensorflow_addons.utils.ensure_tf_install" 11 | 12 | 13 | class Manager: 14 | 15 | def __init__(self, process, input_queue, output_queue): 16 | self.process = process 17 | self.input_queue = input_queue 18 | self.output_queue = output_queue 19 | 20 | 21 | class ChildTask: 22 | 23 | def __init__(self, do_exit: bool, args: Optional[Tuple]) -> None: 24 | self.do_exit = do_exit 25 | self.args = args 26 | 27 | 28 | def manager_function( 29 | input_queue: Queue, output_queue: Queue, 30 | worker_function: Callable) -> None: 31 | # I tried also doing this with a pool (setting maxtasksperchild to 1), 32 | # but for some unknown reason it would sometimes not work (maybe a 33 | # deadlock), but I didn't investigate to understand why. 
34 | logging.getLogger().setLevel(logging.ERROR) 35 | warnings.filterwarnings( 36 | action="ignore", category=UserWarning, module=MODULE_ADDONS_INSTALL) 37 | logging.info("MANAGER: initializing") 38 | worker_result_queue = Queue() 39 | 40 | def worker_function_with_queue(*worker_args) -> None: 41 | result = worker_function(*worker_args) 42 | worker_result_queue.put(result) 43 | 44 | do_exit = False 45 | while not do_exit: 46 | logging.info("MANAGER: waiting for a task") 47 | child_task = input_queue.get() 48 | do_exit = child_task.do_exit 49 | logging.info("MANAGER: Got a task") 50 | if not child_task.do_exit: 51 | child = Process( 52 | target=worker_function_with_queue, args=child_task.args) 53 | child.start() 54 | logging.info("MANAGER: Waiting for result") 55 | output_queue.put(worker_result_queue.get()) 56 | child.join() 57 | logging.info("MANAGER: done, reached end of function.") 58 | -------------------------------------------------------------------------------- /bin/command_user_constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023, Yahoo Inc. 2 | # Licensed under the Apache License, Version 2.0. 3 | # See the accompanying LICENSE file for terms. 4 | 5 | # Options for memory usage setups. 6 | MEMORY_SETUP_256GB = "256" 7 | MEMORY_SETUP_64GB = "64" 8 | # Change the constants below according to your local setup. 
9 | FEATURES_DIR = "./data/features/" 10 | BAIDU_FEATURES_DIR = "./data/features/baidu/" 11 | BAIDU_TWO_FEATURES_DIR = "./data/features/baidu_2.0/" 12 | RESNET_FEATURES_DIR = "./data/features/resnet/" 13 | RESNET_NORMALIZED_FEATURES_DIR = "./data/features/resnet_normalized/" 14 | LABELS_DIR = "./data/labels/" 15 | SPLITS_DIR = "./data/splits/" 16 | BASE_CONFIG_DIR = "./configs/" 17 | MODELS_DIR = "YOUR_MODELS_DIR" 18 | RESULTS_DIR = "YOUR_RESULTS_DIR" 19 | RUN_NAME = "first" 20 | MEMORY_SETUP = MEMORY_SETUP_256GB 21 | 22 | # You might have your data set up in such a way that all the SoccerNet features 23 | # and labels are under the same directory tree. In that case, each game folder 24 | # would contain its respective Baidu (Combination) features, ResNet features, 25 | # and action spotting labels. For example, if all your data is in a base 26 | # directory called "all_soccernet_features_and_labels/", then a game folder 27 | # might have the following files: 28 | # 29 | # $ ls all_soccernet_features_and_labels/england_epl/2014-2015/2015-02-21\ -\ 18-00\ Chelsea\ 1\ -\ 1\ Burnley/ 30 | # 1_baidu_soccer_embeddings.npy 2_baidu_soccer_embeddings.npy 31 | # 1_ResNET_TF2.npy 1_ResNET_TF2_PCA512.npy 2_ResNET_TF2.npy 32 | # 2_ResNET_TF2_PCA512.npy Labels-cameras.json Labels-v2.json 33 | # 34 | # In that case, you could just set some particular constants above to the same 35 | # base directory, as follows: 36 | # BAIDU_FEATURES_DIR = "all_soccernet_features_and_labels/" 37 | # BAIDU_TWO_FEATURES_DIR = "all_soccernet_features_and_labels/" 38 | # RESNET_FEATURES_DIR = "all_soccernet_features_and_labels/" 39 | # RESNET_NORMALIZED_FEATURES_DIR = "all_soccernet_features_and_labels/" 40 | # LABELS_DIR = "all_soccernet_features_and_labels/" 41 | # 42 | # We suggest setting FEATURES_DIR to a different folder, like 43 | # "./data/features/". At the same time, you probably won't have to change 44 | # SPLITS_DIR and BASE_CONFIG_DIR. 
def main() -> None:
    """Extract features for every video found under the input directory.

    Reads the command-line arguments, checks that the input directory
    exists, creates the output features directory if needed, builds the
    requested feature extractor and runs it over the listed videos.

    Raises:
        ValueError: If the input videos path is not a directory.
    """
    args = _get_command_line_arguments()
    logging.getLogger().setLevel(logging.DEBUG)
    videos_dir = Path(args[Args.INPUT_VIDEOS_DIR])
    if not videos_dir.is_dir():
        raise ValueError(f"Input directory failed is_dir(): {videos_dir}")
    output_features_dir = Path(args[Args.FEATURES_DIR])
    output_features_dir.mkdir(parents=True, exist_ok=True)
    models_dir = Path(args[Args.FEATURES_MODELS_DIR])
    extractor = create_feature_extractor(args[Args.FEATURES], models_dir)
    extract_features_from_videos(
        list_video_paths(videos_dir), output_features_dir, extractor)
def main() -> None:
    """Apply a stored PCA transform to every raw feature file.

    Recursively finds the files under the features directory whose name
    ends with the raw feature name plus the numpy extension, transforms
    each one with the PCA transformer, and saves the result next to its
    input with the tag appended to the file stem.
    """
    args = _get_command_line_arguments()
    transformer = PCATransformer(Path(args[Args.PCA_PATH]))
    raw_feature_name = extractor_type_to_feature_name(args[Args.FEATURES])
    glob_pattern = f"**/*{raw_feature_name}{NUMPY_EXTENSION}"
    matching_paths = Path(args[Args.FEATURES_DIR]).glob(glob_pattern)
    tag = args[Args.TAG]
    # Sort so that files are processed in a deterministic order.
    for source_path in sorted(matching_paths):
        target_name = f"{source_path.stem}_{tag}{NUMPY_EXTENSION}"
        target_path = source_path.parent / target_name
        print(f"Going to generate {target_path} from {source_path}.")
        transformed = transformer.transform(np.load(str(source_path)))
        np.save(str(target_path), transformed)
FIELD_NAME = "name"
FIELD_ID = "id"
FIELD_DISPLAY_ORDER = "order"


class LabelMap:
    """Two-way mapping between integer label ids and label names.

    Also keeps the label names as a list sorted by display position
    (display_ordered_labels).
    """

    def __init__(
            self, int_to_label: Dict[int, str],
            display_order: List[int]) -> None:
        """Build the lookup tables.

        Args:
            int_to_label: Maps each integer label id to its name.
            display_order: Indexed by label id: display_order[i] is the
                display position of the label whose id is i. Because the
                list index is used directly as the id, ids are expected to
                be 0..len(display_order) - 1.
        """
        self.int_to_label = int_to_label
        self.label_to_int = {
            label: integer for integer, label in int_to_label.items()}
        # Map the ordered position of a label to its name.
        display_ordered_labels_dict = {
            ordered_position: int_to_label[label_int]
            for label_int, ordered_position in enumerate(display_order)}
        # Convert from dictionary to list.
        self.display_ordered_labels = [
            display_ordered_labels_dict[ordered_position]
            for ordered_position in sorted(display_order)]

    def num_classes(self) -> int:
        """Return the number of labels in the map."""
        return len(self.int_to_label)

    def write(self, label_map_file_path: Path) -> None:
        """Write the map as a CSV file with id, name and order columns.

        Rows are written in display order, so they are generally NOT sorted
        by id; read_label_map() accounts for that.
        """
        # newline="" is what the csv module asks for, so that it controls
        # the line endings itself on every platform.
        with label_map_file_path.open("w", newline="") as csv_file:
            writer = csv.DictWriter(
                csv_file, [FIELD_ID, FIELD_NAME, FIELD_DISPLAY_ORDER])
            writer.writeheader()
            for order, label in enumerate(self.display_ordered_labels):
                row_dict = {
                    FIELD_NAME: label, FIELD_ID: self.label_to_int[label],
                    FIELD_DISPLAY_ORDER: order}
                writer.writerow(row_dict)

    @staticmethod
    def read_label_map(label_map_file_path: Path) -> "LabelMap":
        """Read a LabelMap from a CSV file with id, name and order columns.

        The rows may appear in any order in the file (write() emits them in
        display order). The display positions are re-indexed by label id
        before being handed to the constructor, which expects
        display_order[i] to belong to the label with id i.
        """
        with label_map_file_path.open("r", newline="") as csv_file:
            rows = list(csv.DictReader(csv_file))
        int_to_label = {int(row[FIELD_ID]): row[FIELD_NAME] for row in rows}
        # Sorting by id makes the list position of each entry equal to its
        # label id, independently of the row order used in the file.
        rows_by_id = sorted(rows, key=lambda row: int(row[FIELD_ID]))
        display_order = [int(row[FIELD_DISPLAY_ORDER]) for row in rows_by_id]
        return LabelMap(int_to_label, display_order)
def _create_manager() -> Manager:
    """Assemble the manager process together with its two queues.

    The process is created with manager_function as its target but is not
    started here. Its worker is the lazily-importing validation function.

    Returns:
        A Manager holding the process, its input queue and its output queue.
    """
    task_queue = Queue()
    result_queue = Queue()
    process_args = (
        task_queue, result_queue, _import_and_compute_validation_result)
    process = Process(target=manager_function, args=process_args)
    return Manager(process, task_queue, result_queue)
/spivak/data/soccernet_constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023, Yahoo Inc. 2 | # Licensed under the Apache License, Version 2.0. 3 | # See the accompanying LICENSE file for terms. 4 | # 5 | # This file incorporates work covered by the following copyright and permission 6 | # notice: 7 | # Copyright (c) 2021 Silvio Giancola 8 | # Licensed under the terms of the MIT license. 9 | # You may obtain a copy of the MIT License at https://opensource.org/licenses/MIT 10 | 11 | # This file contains pieces of code taken from the following file. At 12 | # Yahoo Inc., the original code was modified and new code was added. 13 | # https://github.com/SilvioGiancola/SoccerNetv2-DevKit/blob/20f2f74007c82b68a73c519dff852188df4a8b5a/Task2-CameraShotSegmentation/CALF-segmentation/src/config/classes.py 14 | 15 | # Label types 16 | LABEL_FILE_NAME = "Labels.json" # SoccerNet-v1 17 | LABEL_FILE_NAME_V2 = "Labels-v2.json" 18 | LABEL_FILE_NAME_V2_CAMERAS = "Labels-cameras.json" 19 | 20 | # Events as annotated in SoccerNet-v1. 21 | EVENT_DICTIONARY_V1 = { 22 | "soccer-ball": 0, "soccer-ball-own": 0, 23 | "r-card": 1, "y-card": 1, "yr-card": 1, 24 | "substitution-in": 2 25 | } 26 | 27 | # Events as annotated in SoccerNet-v2. 
28 | EVENT_DICTIONARY_V2 = { 29 | "Penalty": 0, 30 | "Kick-off": 1, 31 | "Goal": 2, 32 | "Substitution": 3, 33 | "Offside": 4, 34 | "Shots on target": 5, 35 | "Shots off target": 6, 36 | "Clearance": 7, 37 | "Ball out of play": 8, 38 | "Throw-in": 9, 39 | "Foul": 10, 40 | "Indirect free-kick": 11, 41 | "Direct free-kick": 12, 42 | "Corner": 13, 43 | "Yellow card": 14, 44 | "Red card": 15, 45 | "Yellow->red card": 16 46 | } 47 | 48 | CAMERA_DICTIONARY = { 49 | "Main camera center": 0, 50 | "Close-up player or field referee": 1, 51 | "Main camera left": 2, 52 | "Main camera right": 3, 53 | "Goal line technology camera": 4, 54 | "Main behind the goal": 5, 55 | "Spider camera": 6, 56 | "Close-up side staff": 7, 57 | "Close-up corner": 8, 58 | "Close-up behind the goal": 9, 59 | "Inside the goal": 10, 60 | "Public": 11, 61 | # Note: the original SoccerNet-v2 code only has a lower-case "other" below, 62 | # whereas if we look at the actual camera labels in the JSON files, 63 | # we only find the upper-case "Other", but not the lower-case one. Maybe 64 | # the SoccerNet code was looking to skip "Other" labels, though it doesn't 65 | # skip the "I don't know" labels. 66 | "Other": 12, 67 | "other": 12, 68 | "I don't know": 12 69 | } 70 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023, Yahoo Inc. 2 | # Licensed under the Apache License, Version 2.0. 3 | # See the accompanying LICENSE file for terms. 
4 | 5 | import os 6 | import glob 7 | from setuptools import setup, find_namespace_packages 8 | 9 | SCRIPTS_DIR = "bin" 10 | PACKAGE_DIR = "spivak" 11 | PYTHON_REQUIRES = ">=3.6" 12 | INSTALL_REQUIRES = [ 13 | "numpy>=1.18.5,<1.23.0", 14 | "scipy>=1.4.1,<1.8.0", 15 | "scikit-learn>=0.24.2,<1.1.0", 16 | "pandas>=1.1.5,<1.4.0", 17 | "Pillow>=8.4.0,<9.1.0", 18 | "opencv-python>=4.5.4.58,<4.8.0", 19 | "tqdm~=4.62.3", 20 | # Note: some portion of the code here does not depend on TensorFlow, 21 | # so if you are not interested in using the models, you probably 22 | # don't really have to install TensorFlow. 23 | "tensorflow~=2.7.0", 24 | "tensorboard~=2.7.0", 25 | # This version of tensorflow_probability is needed so that it works with 26 | # Tensorflow 2.3. It also happens to work with 2.7. If you are using 2.7 27 | # or more recent versions of TensorFlow, you can probably upgrade this. 28 | "tensorflow_probability~=0.11.1", 29 | # tensorflow-addons has the decoupled weight decay functionality. 30 | # For TensorFlow 2.3, we need to use this older version (0.13) of the 31 | # addons package. This version also works with TensorFlow 2.7, even though 32 | # it will print out some warnings. 33 | "tensorflow-addons~=0.13.0", 34 | "typing_extensions==4.7.1", 35 | "plotly>=5.4.0,<5.6.0", 36 | # kaleido is used for being able to render pdf plots with plotly. 37 | "kaleido~=0.2.1", 38 | # scikit-video (skvideo), moviepy, imutils are only used in 39 | # spivak.feature_extraction.SoccerNetDataLoader.py. 40 | "scikit-video~=1.1.11", 41 | "moviepy~=1.0.3", 42 | "imageio-ffmpeg==0.4.9", 43 | "imutils~=0.5.4", 44 | # packaging is only used by 45 | # spivak.models.assembly.huggingface_activations.py. 46 | "packaging>=21.2,<21.4" 47 | ] 48 | EXTRAS_REQUIRES = { 49 | "av": [ 50 | # av is used only in bin/create_visualizations.py and inside 51 | # spivak.video_visualization, for creating videos with visualizations 52 | # of results. 
It requires the ffmpeg libraries, so we leave it as an 53 | # optional extra. 54 | "av~=10.0.0" 55 | ] 56 | } 57 | packages = find_namespace_packages(include=[f"{PACKAGE_DIR}.*"]) 58 | scripts = glob.glob(os.path.join(SCRIPTS_DIR, "*.py")) 59 | setup( 60 | name=PACKAGE_DIR, 61 | packages=packages, 62 | python_requires=PYTHON_REQUIRES, 63 | install_requires=INSTALL_REQUIRES, 64 | extras_require=EXTRAS_REQUIRES, 65 | scripts=scripts, 66 | ) 67 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | #### The NCAA dataset 107 | **/bball_dataset_april_4.csv 108 | 109 | ############# 110 | # Certain Pycharm/idea files 111 | ####### 112 | # User-specific stuff 113 | **/.idea/workspace.xml 114 | **/.idea/tasks.xml 115 | **/.idea/usage.statistics.xml 116 | **/.idea/dictionaries 117 | **/.idea/shelf 118 | 119 | 120 | # Generated files 121 | **/.idea/contentModel.xml 122 | 123 | # Sensitive or high-churn files 124 | **/.idea/dataSources/ 125 | **/.idea/dataSources.ids 126 | **/.idea/dataSources.local.xml 127 | **/.idea/sqlDataSources.xml 128 | **/.idea/dynamic.xml 129 | **/.idea/uiDesigner.xml 130 | **/.idea/dbnavigator.xml 131 | 132 | # Gradle 133 | **/.idea/gradle.xml 134 | **/.idea/libraries 135 | 136 | # Mongo Explorer plugin 137 | **/.idea/mongoSettings.xml 138 | 139 | # Cursive Clojure plugin 140 | **/.idea/replstate.xml 141 | 142 | # 
def _gelu_new(x):
    """
    Gaussian Error Linear Unit. This is a smoother version of the GELU. Original paper: https://arxiv.org/abs/1606.08415

    Args:
        x: float Tensor to perform activation

    Returns:
        `x` with the GELU activation applied.
    """
    x = tf.convert_to_tensor(x)
    # Cast the constants to the dtype of x so the arithmetic below does not
    # mix dtypes.
    pi = tf.cast(math.pi, x.dtype)
    coeff = tf.cast(0.044715, x.dtype)
    # tanh-based formula: 0.5 * (1 + tanh(sqrt(2 / pi) * (x + coeff * x^3))).
    cdf = 0.5 * (1.0 + tf.tanh(tf.sqrt(2.0 / pi) * (x + coeff * tf.pow(x, 3))))

    return x * cdf
def _create_combined_file(video_info: VideoFeatureInfo) -> None:
    """Stack the result arrays of one video into a single saved array.

    Loads every input file of video_info, clips all arrays to the smallest
    number of frames (first dimension) found among them, stacks them along
    a new axis 2 and saves the outcome at video_info.output_path.
    """
    loaded = [
        np.load(str(input_path)) for input_path in video_info.input_paths]
    # The inputs may disagree on the number of frames, so clip them all to
    # the shortest one before stacking.
    shortest = min(array.shape[0] for array in loaded)
    stacked = np.stack(
        [clip_frames(array, shortest) for array in loaded], axis=2)
    print(f"Writing combined results to {video_info.output_path}")
    np.save(str(video_info.output_path), stacked)
class VideoFeatureInfo:
    """Groups the input feature files of one video with its output file."""

    def __init__(self, input_paths: List[Path], output_path: Path) -> None:
        """Store the paths as given.

        Args:
            input_paths: Paths of the files to read for this video.
            output_path: Path of the file to be written for this video.
        """
        self.input_paths, self.output_path = input_paths, output_path
all_features_input_paths[0][video_index] 37 | relative_dir = first_input_path.parent.relative_to(input_dirs[0]) 38 | video_feature_input_paths = [ 39 | current_input_paths[video_index] 40 | for current_input_paths in all_features_input_paths] 41 | video_feature_info = _create_video_feature_info( 42 | video_feature_input_paths, relative_dir, output_dir, feature_name) 43 | video_feature_infos.append(video_feature_info) 44 | return video_feature_infos 45 | 46 | 47 | def read_and_concatenate_features( 48 | features_dir: Path, game_list: List[Path], 49 | feature_name: str) -> np.ndarray: 50 | features_iterator = itertools.chain.from_iterable( 51 | _read_game_features(features_dir / game, feature_name) 52 | for game in game_list) 53 | return np.concatenate(list(features_iterator)) 54 | 55 | 56 | def _create_video_feature_info( 57 | input_paths: List[Path], relative_dir: Path, output_dir: Path, 58 | feature_name: str) -> VideoFeatureInfo: 59 | game_half = input_paths[0].stem[0] 60 | output_path = output_dir / relative_dir / f"{game_half}_{feature_name}.npy" 61 | return VideoFeatureInfo(input_paths, output_path) 62 | 63 | 64 | def _read_game_features( 65 | game_dir: Path, feature_name: str) -> List[np.ndarray]: 66 | if not game_dir.is_dir(): 67 | raise ValueError(f"Could not find game directory: {game_dir}") 68 | print(f"Reading features from {game_dir}") 69 | first = np.load(str(game_dir / f"1_{feature_name}.npy")) 70 | second = np.load(str(game_dir / f"2_{feature_name}.npy")) 71 | return [first, second] 72 | -------------------------------------------------------------------------------- /spivak/models/delta_dense_predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023, Yahoo Inc. 2 | # Licensed under the Apache License, Version 2.0. 3 | # See the accompanying LICENSE file for terms. 
4 | 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | 9 | from spivak.data.dataset import VideoDatum 10 | from spivak.data.output_names import OUTPUT_DETECTION_SCORE 11 | from spivak.models.dense_predictor import DensePredictor, \ 12 | create_detection_scores, OUTPUT_CONFIDENCE 13 | from spivak.models.non_maximum_suppression import FlexibleNonMaximumSuppression 14 | from spivak.models.predictor import PredictorInterface, VideoOutputs 15 | 16 | 17 | class DeltaDensePredictor(PredictorInterface): 18 | 19 | def __init__( 20 | self, dense_predictor: DensePredictor, 21 | confidence_dir: Path) -> None: 22 | self._dense_predictor = dense_predictor 23 | self._confidence_dir = confidence_dir 24 | 25 | def predict_video(self, video_datum: VideoDatum) -> VideoOutputs: 26 | video_outputs = self._dense_predictor.predict_video_base( 27 | video_datum.features) 28 | confidence_path = DeltaDensePredictor.confidence_path( 29 | self._confidence_dir, video_datum.relative_path) 30 | video_outputs[OUTPUT_CONFIDENCE] = np.load(str(confidence_path)) 31 | # Fix possible mismatch in number of frames. 
def clip_frames(unclipped: np.ndarray, clipped_num_frames: int) -> np.ndarray:
    """Return the array limited to clipped_num_frames along its first axis.

    Only a mismatch of exactly one trailing frame is tolerated. Any other
    difference is reported as an error instead of being silently truncated.

    Args:
        unclipped: Array whose first dimension is the number of frames.
        clipped_num_frames: Desired number of frames. Must equal the current
            number of frames or be exactly one less.

    Returns:
        The input array itself when it already has clipped_num_frames
        frames, otherwise a slice of it without its last frame.

    Raises:
        ValueError: If the frame counts differ by anything other than a
            single extra frame in the input array.
    """
    unclipped_num_frames = unclipped.shape[0]
    if unclipped_num_frames == clipped_num_frames:
        return unclipped
    # This used to be an assert, which is removed when Python runs with -O
    # and would then let arbitrary mismatches through. The arrays come from
    # files on disk, so validate explicitly.
    if clipped_num_frames != unclipped_num_frames - 1:
        raise ValueError(
            f"Cannot clip {unclipped_num_frames} frames to "
            f"{clipped_num_frames}: only a difference of exactly one frame "
            f"is supported")
    return unclipped[:clipped_num_frames]
class EvaluationAggregate:

    """Collects the evaluations of the different tasks and summarizes them.

    The main metric comes from the spotting evaluation when one is given,
    otherwise from the segmentation evaluation, otherwise it is 0.0 with no
    name.
    """

    def __init__(
            self, spotting_evaluation: Optional[SpottingEvaluation],
            segmentation_evaluation: Optional[SegmentationEvaluation]) -> None:
        """Store the evaluations and pick the main metric.

        Args:
            spotting_evaluation: Spotting results, or None.
            segmentation_evaluation: Segmentation results, or None.
        """
        self.spotting_evaluation = spotting_evaluation
        self.segmentation_evaluation = segmentation_evaluation
        # Defaults used when no evaluation is available.
        main_metric = 0.0
        main_metric_name = None
        if spotting_evaluation:
            # Spotting takes precedence: use the average mAP of its main
            # tolerances.
            tolerances_name = spotting_evaluation.main_tolerances_name
            main_metric = spotting_evaluation.average_map_dict[tolerances_name]
            main_metric_name = (
                f"{SpottingEvaluation.METRIC_AVERAGE_MAP}_{tolerances_name}")
        elif segmentation_evaluation:
            main_metric = segmentation_evaluation.mean_iou
            main_metric_name = SegmentationEvaluation.METRIC_MEAN_IOU
        self.main_metric = main_metric
        self.main_metric_name = main_metric_name
        # Keep only the evaluations that are present, spotting first.
        self._task_evaluations: List[TaskEvaluation] = []
        for candidate in (spotting_evaluation, segmentation_evaluation):
            if candidate:
                self._task_evaluations.append(candidate)

    def scalars_for_logging(self) -> Dict[str, float]:
        """Merge the logging scalars of all stored task evaluations."""
        scalars = {}
        for task_evaluation in self._task_evaluations:
            scalars.update(task_evaluation.scalars_for_logging())
        return scalars

    def save_txt(self, save_dir: str) -> None:
        """Write the text summary to the aggregate text file in save_dir."""
        txt_path = os.path.join(save_dir, EVALUATION_AGGREGATE_TEXT_FILE_NAME)
        with open(txt_path, "w") as txt_file:
            txt_file.write(str(self))

    def save_pkl(self, save_dir: str) -> None:
        """Pickle this object to the aggregate pickle file in save_dir."""
        pkl_path = os.path.join(
            save_dir, EVALUATION_AGGREGATE_PICKLE_FILE_NAME)
        with open(pkl_path, "wb") as pkl_file:
            pickle.dump(self, pkl_file)

    def __str__(self):
        """Return the summaries of all task evaluations, one per line."""
        with StringIO() as str_io:
            self._write_str(str_io)
            return str_io.getvalue()

    def _write_str(self, str_io: StringIO) -> None:
        """Append each task summary followed by a newline to str_io."""
        for task_evaluation in self._task_evaluations:
            str_io.write(task_evaluation.summary() + "\n")
Kiev 2 - 2 FC Porto", 30 | "2015-10-20 - 21-45 Arsenal 2 - 0 Bayern Munich" 31 | ], 32 | "2016-2017": [ 33 | "2017-04-12 - 21-45 Bayern Munich 1 - 2 Real Madrid" 34 | ] 35 | }, 36 | "france_ligue-1": { 37 | "2015-2016": [ 38 | "2015-09-19 - 18-30 Reims 1 - 1 Paris SG" 39 | ] 40 | }, 41 | "germany_bundesliga": { 42 | "2014-2015": [ 43 | "2015-04-11 - 16-30 Bayern Munich 3 - 0 Eintracht Frankfurt" 44 | ] 45 | }, 46 | "italy_serie-a": { 47 | "2014-2015": [ 48 | "2015-04-19 - 21-45 Inter 0 - 0 AC Milan" 49 | ], 50 | "2016-2017": [ 51 | "2016-08-21 - 21-45 Pescara 2 - 2 Napoli", 52 | "2016-11-06 - 17-00 Palermo 1 - 2 AC Milan", 53 | "2017-01-15 - 17-00 Udinese 0 - 1 AS Roma", 54 | "2017-01-29 - 17-00 Sampdoria 3 - 2 AS Roma", 55 | "2017-02-19 - 20-00 AS Roma 4 - 1 Torino", 56 | "2017-02-26 - 22-45 Inter 1 - 3 AS Roma", 57 | "2017-04-09 - 16-00 Bologna 0 - 3 AS Roma", 58 | "2017-04-15 - 16-00 AS Roma 1 - 1 Atalanta", 59 | "2017-05-20 - 19-00 Chievo 3 - 5 AS Roma" 60 | ] 61 | }, 62 | "spain_laliga": { 63 | "2014-2015": [ 64 | "2015-02-21 - 18-00 Barcelona 0 - 1 Malaga" 65 | ], 66 | "2015-2016": [ 67 | "2015-08-29 - 23-30 Real Madrid 5 - 0 Betis", 68 | "2015-09-12 - 21-30 Atl. Madrid 1 - 2 Barcelona", 69 | "2015-09-23 - 22-00 Ath Bilbao 1 - 2 Real Madrid", 70 | "2015-09-26 - 19-15 Real Madrid 0 - 0 Malaga", 71 | "2015-10-04 - 21-30 Atl. Madrid 1 - 1 Real Madrid", 72 | "2016-05-08 - 18-00 Real Madrid 3 - 2 Valencia", 73 | "2016-05-14 - 18-00 Dep. 
La Coruna 0 - 2 Real Madrid" 74 | ], 75 | "2016-2017": [ 76 | "2017-01-07 - 15-00 Real Madrid 5 - 0 Granada CF", 77 | "2017-01-15 - 22-45 Sevilla 2 - 1 Real Madrid" 78 | ] 79 | } 80 | } -------------------------------------------------------------------------------- /data/splits/SoccerNetCameraChangesTest.json: -------------------------------------------------------------------------------- 1 | { 2 | "england_epl": { 3 | "2014-2015": [ 4 | "2015-02-22 - 19-15 Southampton 0 - 2 Liverpool" 5 | ], 6 | "2015-2016": [ 7 | "2015-10-03 - 19-30 Chelsea 1 - 3 Southampton", 8 | "2015-12-05 - 20-30 Chelsea 0 - 1 Bournemouth", 9 | "2016-01-03 - 16-30 Crystal Palace 0 - 3 Chelsea", 10 | "2016-01-23 - 20-30 West Ham 2 - 2 Manchester City", 11 | "2016-02-07 - 19-00 Chelsea 1 - 1 Manchester United", 12 | "2016-02-13 - 20-30 Chelsea 5 - 1 Newcastle Utd" 13 | ], 14 | "2016-2017": [ 15 | "2016-08-14 - 18-00 Arsenal 3 - 4 Liverpool", 16 | "2016-10-02 - 18-30 Burnley 0 - 1 Arsenal", 17 | "2016-11-06 - 17-15 Liverpool 6 - 1 Watford", 18 | "2016-11-19 - 18-00 Southampton 0 - 0 Liverpool", 19 | "2016-12-10 - 20-30 Leicester 4 - 2 Manchester City" 20 | ] 21 | }, 22 | "europe_uefa-champions-league": { 23 | "2014-2015": [ 24 | "2014-11-04 - 22-45 Dortmund 4 - 1 Galatasaray" 25 | ], 26 | "2015-2016": [ 27 | "2015-09-15 - 21-45 Sevilla 3 - 0 B. Monchengladbach", 28 | "2015-09-16 - 21-45 Bayer Leverkusen 4 - 1 BATE", 29 | "2015-09-16 - 21-45 Olympiakos Piraeus 0 - 3 Bayern Munich", 30 | "2015-09-29 - 21-45 Bayern Munich 5 - 0 D. Zagreb", 31 | "2015-11-03 - 18-00 FC Astana 0 - 0 Atl. 
Madrid" 32 | ], 33 | "2016-2017": [ 34 | "2016-11-23 - 22-45 Arsenal 2 - 2 Paris SG", 35 | "2016-11-23 - 22-45 Celtic 0 - 2 Barcelona", 36 | "2017-03-08 - 22-45 Barcelona 6 - 1 Paris SG" 37 | ] 38 | }, 39 | "france_ligue-1": { 40 | "2016-2017": [ 41 | "2016-08-21 - 21-45 Paris SG 3 - 0 Metz" 42 | ] 43 | }, 44 | "germany_bundesliga": { 45 | "2014-2015": [ 46 | "2015-04-25 - 16-30 Dortmund 2 - 0 Eintracht Frankfurt" 47 | ], 48 | "2015-2016": [ 49 | "2015-09-26 - 16-30 1. FSV Mainz 05 0 - 3 Bayern Munich" 50 | ], 51 | "2016-2017": [ 52 | "2016-09-17 - 16-30 Dortmund 6 - 0 Darmstadt" 53 | ] 54 | }, 55 | "italy_serie-a": { 56 | "2016-2017": [ 57 | "2016-08-27 - 21-45 Napoli 4 - 2 AC Milan", 58 | "2016-10-30 - 17-00 Empoli 0 - 0 AS Roma", 59 | "2016-11-05 - 22-45 Napoli 1 - 1 Lazio", 60 | "2017-02-07 - 22-45 AS Roma 4 - 0 Fiorentina", 61 | "2017-02-12 - 14-30 Crotone 0 - 2 AS Roma", 62 | "2017-03-12 - 22-45 Palermo 0 - 3 AS Roma", 63 | "2017-04-01 - 21-45 AS Roma 2 - 0 Empoli" 64 | ] 65 | }, 66 | "spain_laliga": { 67 | "2014-2015": [ 68 | "2015-02-14 - 20-00 Real Madrid 2 - 0 Dep. La Coruna", 69 | "2015-02-22 - 23-00 Elche 0 - 2 Real Madrid", 70 | "2015-04-11 - 17-00 Real Madrid 3 - 0 Eibar", 71 | "2015-04-25 - 17-00 Espanyol 0 - 2 Barcelona", 72 | "2015-05-02 - 19-00 Atl. Madrid 0 - 0 Ath Bilbao", 73 | "2015-05-17 - 20-00 Atl. 
Madrid 0 - 1 Barcelona" 74 | ], 75 | "2016-2017": [ 76 | "2017-01-08 - 22-45 Villarreal 1 - 1 Barcelona", 77 | "2017-04-15 - 21-45 Barcelona 3 - 2 Real Sociedad" 78 | ] 79 | } 80 | } -------------------------------------------------------------------------------- /data/splits/SoccerNetGamesChallenge.json: -------------------------------------------------------------------------------- 1 | { 2 | "england_epl": { 3 | "2016-2017": [ 4 | "2017-05-13 - 18-00 Stoke City 1 - 4 Arsenal", 5 | "2017-04-30 - 18-00 Tottenham 2 - 0 Arsenal", 6 | "2017-04-10 - 18-00 Crystal Palace 3 - 0 Arsenal", 7 | "2017-04-02 - 18-00 Arsenal 2 - 2 Manchester City", 8 | "2017-03-18 - 18-00 West Brom 3 - 1 Arsenal", 9 | "2017-02-04 - 18-00 Chelsea 3 - 1 Arsenal", 10 | "2017-01-22 - 18-00 Arsenal 2 - 1 Burnley", 11 | "2017-01-03 - 18-00 Bournemouth 3 - 3 Arsenal", 12 | "2016-11-19 - 18-00 Manchester United 1 - 1 Arsenal" 13 | ] 14 | }, 15 | "europe_uefa-champions-league": { 16 | "2016-2017": [ 17 | "2017-05-09 - 18-00 Juventus 2 - 1 Monaco", 18 | "2017-04-11 - 18-00 Juventus 3 - 0 Barcelona", 19 | "2017-03-07 - 18-00 Arsenal 1 - 5 Bayern", 20 | "2017-02-15 - 18-00 Bayern 5 - 1 Arsenal", 21 | "2016-12-06 - 18-00 Basel 1 - 4 Arsenal", 22 | "2016-11-06 - 18-00 Ludogorets Razgrad 2 - 3 Arsenal", 23 | "2016-09-26 - 18-00 Arsenal 2 - 0 Basel" 24 | ] 25 | }, 26 | "france_ligue-1": { 27 | "2016-2017": [ 28 | "2017-02-24 - 18-00 Nice 2 - 1 Montpellier", 29 | "2016-09-11 - 18-00 Nice 3 - 2 Marseille", 30 | "2017-04-02 - 18-00 Nice 2 - 1 Bordeaux", 31 | "2017-05-07 - 18-00 Marseille 2 - 1 Nice", 32 | "2017-02-18 - 18-00 Lorient 0 - 1 Nice", 33 | "2017-02-08 - 18-00 Nice 1 - 0 Saint Etienne", 34 | "2016-12-11 - 18-00 Paris SG 2 - 2 Nice", 35 | "2016-12-04 - 18-00 Nice 3 - 0 Toulouse", 36 | "2016-11-20 - 18-00 Saint Etienne 0 - 1 Nice" 37 | ] 38 | }, 39 | "germany_bundesliga": { 40 | "2016-2017": [ 41 | "2016-09-30 - 18-00 Leipzig 2 - 1 Augsburg", 42 | "2016-09-21 - 18-00 Leipzig 1 - 1 Borussia", 43 | 
"2017-05-13 - 18-00 Leipzig 4 - 5 Bayern", 44 | "2017-02-11 - 18-00 Leipzig 0 - 3 Hamburger", 45 | "2017-01-28 - 18-00 Leipzig 2 - 1 Hoffenheim", 46 | "2017-01-21 - 18-00 Leipzig 3 - 0 Frankfurt", 47 | "2016-12-20 - 18-00 Bayern 3 - 0 Leipzig", 48 | "2016-12-17 - 18-00 Leipzig 2 - 0 Hertha" 49 | ] 50 | }, 51 | "italy_serie-a": { 52 | "2016-2017": [ 53 | "2016-09-24 - 18-00 Palermo 0 - 1 Juventus", 54 | "2016-09-18 - 18-00 Empoli 0 - 2 Internazionale", 55 | "2016-09-18 - 18-00 Juventus 1 - 0 Internazionale", 56 | "2016-08-27 - 18-00 Lazio 0 - 1 Juventus", 57 | "2017-03-18 - 18-00 Torino 2 - 2 Internazionale", 58 | "2017-03-05 - 18-00 Cagliari 1 - 5 Internazionale", 59 | "2016-08-20 - 18-00 Juventus 2 - 1 Fiorentina", 60 | "2016-12-01 - 18-00 Napoli 3 - 0 Internazionale", 61 | "2016-11-06 - 18-00 Internazionale 3 - 0 Crotone" 62 | ] 63 | }, 64 | "spain_laliga": { 65 | "2019-2020": [ 66 | "2020-02-01 - 18-00 Real Madrid 1 - 0 Atletico Madrid", 67 | "2019-10-05 - 18-00 Real Madrid 4 - 2 Granada", 68 | "2019-09-14 - 18-00 Barcelona 5 - 2 Valencia", 69 | "2020-02-16 - 18-00 Real Madrid 2 - 2 Celta Vigo", 70 | "2020-02-09 - 18-00 Osasuna 1 - 4 Real Madrid", 71 | "2019-12-15 - 18-00 Valencia 1 - 1 Real Madrid", 72 | "2019-12-14 - 18-00 Real Sociedad 2 - 2 Barcelona", 73 | "2019-08-17 - 18-00 Celta Vigo 1 - 3 Real Madrid" 74 | ] 75 | } 76 | } -------------------------------------------------------------------------------- /data/splits/SoccerNetCameraChangesChallenge.json: -------------------------------------------------------------------------------- 1 | { 2 | "england_epl": { 3 | "2016-2017": [ 4 | "2017-05-13 - 18-00 Stoke City 1 - 4 Arsenal", 5 | "2017-04-30 - 18-00 Tottenham 2 - 0 Arsenal", 6 | "2017-04-10 - 18-00 Crystal Palace 3 - 0 Arsenal", 7 | "2017-04-02 - 18-00 Arsenal 2 - 2 Manchester City", 8 | "2017-03-18 - 18-00 West Brom 3 - 1 Arsenal", 9 | "2017-02-04 - 18-00 Chelsea 3 - 1 Arsenal", 10 | "2017-01-22 - 18-00 Arsenal 2 - 1 Burnley", 11 | "2017-01-03 - 18-00 
Bournemouth 3 - 3 Arsenal", 12 | "2016-11-19 - 18-00 Manchester United 1 - 1 Arsenal" 13 | ] 14 | }, 15 | "europe_uefa-champions-league": { 16 | "2016-2017": [ 17 | "2017-05-09 - 18-00 Juventus 2 - 1 Monaco", 18 | "2017-04-11 - 18-00 Juventus 3 - 0 Barcelona", 19 | "2017-03-07 - 18-00 Arsenal 1 - 5 Bayern", 20 | "2017-02-15 - 18-00 Bayern 5 - 1 Arsenal", 21 | "2016-12-06 - 18-00 Basel 1 - 4 Arsenal", 22 | "2016-11-06 - 18-00 Ludogorets Razgrad 2 - 3 Arsenal", 23 | "2016-09-26 - 18-00 Arsenal 2 - 0 Basel" 24 | ] 25 | }, 26 | "france_ligue-1": { 27 | "2016-2017": [ 28 | "2017-02-24 - 18-00 Nice 2 - 1 Montpellier", 29 | "2016-09-11 - 18-00 Nice 3 - 2 Marseille", 30 | "2017-04-02 - 18-00 Nice 2 - 1 Bordeaux", 31 | "2017-05-07 - 18-00 Marseille 2 - 1 Nice", 32 | "2017-02-18 - 18-00 Lorient 0 - 1 Nice", 33 | "2017-02-08 - 18-00 Nice 1 - 0 Saint Etienne", 34 | "2016-12-11 - 18-00 Paris SG 2 - 2 Nice", 35 | "2016-12-04 - 18-00 Nice 3 - 0 Toulouse", 36 | "2016-11-20 - 18-00 Saint Etienne 0 - 1 Nice" 37 | ] 38 | }, 39 | "germany_bundesliga": { 40 | "2016-2017": [ 41 | "2016-09-30 - 18-00 Leipzig 2 - 1 Augsburg", 42 | "2016-09-21 - 18-00 Leipzig 1 - 1 Borussia", 43 | "2017-05-13 - 18-00 Leipzig 4 - 5 Bayern", 44 | "2017-02-11 - 18-00 Leipzig 0 - 3 Hamburger", 45 | "2017-01-28 - 18-00 Leipzig 2 - 1 Hoffenheim", 46 | "2017-01-21 - 18-00 Leipzig 3 - 0 Frankfurt", 47 | "2016-12-20 - 18-00 Bayern 3 - 0 Leipzig", 48 | "2016-12-17 - 18-00 Leipzig 2 - 0 Hertha" 49 | ] 50 | }, 51 | "italy_serie-a": { 52 | "2016-2017": [ 53 | "2016-09-24 - 18-00 Palermo 0 - 1 Juventus", 54 | "2016-09-18 - 18-00 Empoli 0 - 2 Internazionale", 55 | "2016-09-18 - 18-00 Juventus 1 - 0 Internazionale", 56 | "2016-08-27 - 18-00 Lazio 0 - 1 Juventus", 57 | "2017-03-18 - 18-00 Torino 2 - 2 Internazionale", 58 | "2017-03-05 - 18-00 Cagliari 1 - 5 Internazionale", 59 | "2016-08-20 - 18-00 Juventus 2 - 1 Fiorentina", 60 | "2016-12-01 - 18-00 Napoli 3 - 0 Internazionale", 61 | "2016-11-06 - 18-00 Internazionale 3 - 
0 Crotone" 62 | ] 63 | }, 64 | "spain_laliga": { 65 | "2019-2020": [ 66 | "2020-02-01 - 18-00 Real Madrid 1 - 0 Atletico Madrid", 67 | "2019-10-05 - 18-00 Real Madrid 4 - 2 Granada", 68 | "2019-09-14 - 18-00 Barcelona 5 - 2 Valencia", 69 | "2020-02-16 - 18-00 Real Madrid 2 - 2 Celta Vigo", 70 | "2020-02-09 - 18-00 Osasuna 1 - 4 Real Madrid", 71 | "2019-12-15 - 18-00 Valencia 1 - 1 Real Madrid", 72 | "2019-12-14 - 18-00 Real Sociedad 2 - 2 Barcelona", 73 | "2019-08-17 - 18-00 Celta Vigo 1 - 3 Real Madrid" 74 | ] 75 | } 76 | } -------------------------------------------------------------------------------- /bin/create_normalizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2023, Yahoo Inc. 4 | # Licensed under the Apache License, Version 2.0. 5 | # See the accompanying LICENSE file for terms. 6 | 7 | import argparse 8 | import pickle 9 | from pathlib import Path 10 | from typing import Dict 11 | 12 | import numpy as np 13 | from sklearn import preprocessing 14 | 15 | from spivak.application.feature_utils import read_and_concatenate_features 16 | from spivak.data.dataset_splits import SPLIT_KEY_TRAIN 17 | from spivak.data.soccernet_reader import GamePathsReader 18 | 19 | NORMALIZER_MIN_MAX = "min_max" 20 | NORMALIZER_STANDARD = "standard" 21 | NORMALIZER_STANDARD_NO_MEAN = "standard_no_mean" 22 | NORMALIZER_MAX_ABS = "max_abs" 23 | 24 | 25 | class Args: 26 | FEATURES_DIR = "features_dir" 27 | SPLITS_DIR = "splits_dir" 28 | NORMALIZER = "normalizer" 29 | FEATURE_NAME = "feature_name" 30 | OUT_PATH = "out_path" 31 | 32 | 33 | def main(): 34 | args = _get_command_line_arguments() 35 | feature_name = args[Args.FEATURE_NAME] 36 | normalizer_type = args[Args.NORMALIZER] 37 | features_dir = Path(args[Args.FEATURES_DIR]) 38 | splits_dir = Path(args[Args.SPLITS_DIR]) 39 | print("Reading all the features") 40 | features = _read_features(features_dir, splits_dir, feature_name) 41 | # Create 
def _get_command_line_arguments() -> Dict:
    """Parse the command line and return the arguments as a dictionary.

    Every option is a required string. The normalizer option is additionally
    restricted to the known normalizer types.
    """
    normalizer_choices = [
        NORMALIZER_STANDARD, NORMALIZER_STANDARD_NO_MEAN,
        NORMALIZER_MIN_MAX, NORMALIZER_MAX_ABS]
    # (option name, option-specific keyword arguments), in the order in
    # which the options are registered with the parser.
    argument_specs = (
        (Args.OUT_PATH, {"help": "Output pickle file path"}),
        (Args.FEATURES_DIR, {
            "help": "Directory from which to read the video features"}),
        (Args.SPLITS_DIR, {
            "help": "Directory containing splits information"}),
        (Args.FEATURE_NAME, {"help": "What type of features to read"}),
        (Args.NORMALIZER, {
            "choices": normalizer_choices,
            "help": "Type of the normalizer"}),
    )
    parser = argparse.ArgumentParser()
    for argument_name, extra_options in argument_specs:
        parser.add_argument(
            "--" + argument_name, required=True, type=str, **extra_options)
    return vars(parser.parse_args())
def main() -> None:
    """Build the multi-task games CSV from the split files and report counts.

    Reads the spotting game lists and the camera-segmentation game set from
    the splits directory given on the command line, writes one CSV row per
    spotting game, then prints a per-split summary.
    """
    splits_root = Path(_get_command_line_arguments()[ARG_SPLITS_DIR])
    # Arguments are evaluated left to right, so the spotting lists are still
    # read before the segmentation set.
    rows = _prepare_out_rows(
        _read_spotting_game_paths_dict(splits_root),
        _read_segmentation_game_paths_set(splits_root))
    _write_rows(rows)
    _print_summary(rows)
def _write_rows(out_rows: List[Dict]) -> None:
    """Write the rows to OUT_CSV_PATH as CSV, preceded by a header line.

    Args:
        out_rows: One dictionary per game, keyed by the names in OUT_FIELDS.
    """
    print(f"Writing output to {OUT_CSV_PATH}")
    # The csv module does its own newline handling, so the file must be
    # opened with newline="". Without it, platforms that translate "\n" on
    # write (e.g. Windows) produce "\r\r\n" row endings.
    with OUT_CSV_PATH.open("w", newline="") as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=OUT_FIELDS)
        writer.writeheader()
        writer.writerows(out_rows)
class TrainerInterface:
    """Contract for objects that compile, fit and save a Keras model.

    NOTE(review): the class does not derive from abc.ABC, so the
    @abstractmethod markers below document intent but do not prevent
    instantiating a subclass that leaves a member unimplemented.
    """

    @abstractmethod
    def compile(self, optimizer: Optimizer) -> None:
        """Prepare the underlying model for training with the optimizer."""

    @abstractmethod
    def fit(self, initial_epoch: int, epochs: int, callbacks,
            validation_freq) -> History:
        """Run training and return the resulting Keras History object."""

    @abstractmethod
    def save_model(self, model_path: str) -> None:
        """Save the underlying model to model_path."""

    @property
    @abstractmethod
    def model(self) -> Model:
        """The Keras model being trained."""

    @property
    @abstractmethod
    def steps_per_epoch(self) -> int:
        """Number of training steps that make up one epoch."""
with the Tensorflow Dataset was the fastest solution 58 | # I could find, so we use that here. It does require converting some 59 | # of the data pre-processing components to Tensorflow, which can be 60 | # annoying sometimes. 61 | self._fitting_training_set = fitting_training_set 62 | self._fitting_validation_set = fitting_validation_set 63 | 64 | def compile(self, optimizer: Optimizer) -> None: 65 | losses = [ 66 | trainer_head.predictor_head.loss 67 | for trainer_head in self._trainer_heads] 68 | loss_weights = [ 69 | trainer_head.predictor_head.loss_weight 70 | for trainer_head in self._trainer_heads] 71 | # Keras has some memory leaks that we are trying to work around. When 72 | # using run_eagerly=True, the leak is smaller. In theory, 73 | # it is slower, but I didn't notice a significant decrease in speed 74 | # in some experiments. 75 | self._model.compile( 76 | loss=losses, optimizer=optimizer, loss_weights=loss_weights, 77 | run_eagerly=True) 78 | 79 | def fit(self, initial_epoch: int, epochs: int, callbacks, 80 | validation_freq) -> History: 81 | if self._fitting_validation_set: 82 | validation_tf_dataset = self._fitting_validation_set.tf_dataset 83 | else: 84 | validation_tf_dataset = None 85 | return self._model.fit( 86 | self._fitting_training_set.tf_dataset, 87 | initial_epoch=initial_epoch, 88 | epochs=epochs, callbacks=callbacks, 89 | steps_per_epoch=self.steps_per_epoch, 90 | validation_data=validation_tf_dataset, 91 | validation_freq=validation_freq, 92 | validation_steps=self.validation_steps) 93 | 94 | def save_model(self, model_path: str) -> None: 95 | self._model.save(model_path) 96 | 97 | @property 98 | def model(self) -> Model: 99 | return self._model 100 | 101 | @property 102 | def steps_per_epoch(self) -> int: 103 | return self._fitting_training_set.batches_per_epoch 104 | 105 | @property 106 | def validation_steps(self) -> Optional[int]: 107 | if not self._fitting_validation_set: 108 | return None 109 | return 
class SAMModel(Model):
    """Keras model whose training step follows the SAM two-pass scheme.

    Each step computes the gradient at the current weights, moves the
    weights along that gradient by a step scaled by rho, recomputes the
    gradient at the moved weights, restores the weights, and lets the
    optimizer apply the second gradient.

    train_step only relies on self._rho and self._eps besides the standard
    Keras model attributes, which is what allows
    maybe_convert_model_to_sam to bind it onto an existing model.
    """

    def __init__(
            self, main_input, output_tensors, rho: float, eps: float) -> None:
        """Build the model from its input/output tensors.

        Args:
            main_input: Passed on to the Keras Model constructor as inputs.
            output_tensors: Passed on to the Keras Model constructor as
                outputs.
            rho: Scale of the weight perturbation.
            eps: Small constant added to the gradient norm before dividing.
        """
        super(SAMModel, self).__init__(main_input, output_tensors)
        self._rho = rho
        self._eps = eps

    def train_step(self, data):
        """Run one SAM training step and return the current metric values.

        Args:
            data: Either (x, y) or (x, y, sample_weight).

        Returns:
            A dict mapping each metric name to its current result.
        """
        # Unpack the data. Its structure depends on your models and
        # on what you pass to `fit()`.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            sample_weight = None
            x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(
                y, y_pred, sample_weight=sample_weight,
                regularization_losses=self.losses)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # first step
        # Perturb every variable by e_w = rho * g / (||g|| + eps), where
        # ||g|| is the global norm over all gradients. The perturbations are
        # remembered so that they can be undone after the second pass.
        e_ws = []
        grad_norm = tf.linalg.global_norm(gradients)
        for i in range(len(trainable_vars)):
            if gradients[i] is not None:
                e_w = tf.math.scalar_mul(self._rho, gradients[i]) / (
                    grad_norm + self._eps)
            else:
                # No gradient for this variable: use a zero perturbation of
                # the variable's own shape so the bookkeeping stays uniform.
                e_w = tf.math.scalar_mul(0.0, trainable_vars[i])
            trainable_vars[i].assign_add(e_w)
            e_ws.append(e_w)

        # Second pass: forward/backward at the perturbed weights.
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(
                y, y_pred, sample_weight=sample_weight,
                regularization_losses=self.losses)

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Undo the perturbation before the optimizer update, so that the
        # second-pass gradients are applied to the original weights.
        for i in range(len(trainable_vars)):
            trainable_vars[i].assign_sub(e_ws[i])
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics.
        # Metrics are configured in `compile()`.
        # Note that y_pred here comes from the second (perturbed) pass.
        self.compiled_metrics.update_state(
            y, y_pred, sample_weight=sample_weight)

        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}
def _get_command_line_arguments() -> Dict:
    """Parse the command line and return the arguments as a dictionary.

    The whiten/no_whiten flags both write to the same destination, which
    defaults to DEFAULT_WHITEN; if both are given, the last one wins.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--" + Args.OUT_PATH, required=False,
        help="Optional: output file path (defaults to a file inside the "
             "features directory)")
    # The script reads the features from this directory, so the help text
    # says so, in line with the wording used by create_normalizer.
    parser.add_argument(
        "--" + Args.FEATURES_DIR, required=True,
        help="Directory from which to read the video features")
    parser.add_argument(
        "--" + Args.FEATURES, required=False,
        help="What type of features to use", default=EXTRACTOR_TYPE_RESNET_TF2,
        choices=[EXTRACTOR_TYPE_RESNET_TF2])
    parser.add_argument(
        "--" + Args.SPLITS_DIR, required=True, type=str,
        help="Directory containing splits information")
    parser.add_argument(
        "--" + Args.WHITEN, help="Whiten the transformation", required=False,
        action='store_true', dest=Args.WHITEN)
    parser.add_argument(
        "--" + Args.NO_WHITEN, help="Don't whiten the transformation",
        required=False, action='store_false', dest=Args.WHITEN)
    parser.set_defaults(**{Args.WHITEN: DEFAULT_WHITEN})
    return vars(parser.parse_args())
features_dir: Path, extractor_type: str, n_components: int, 94 | whiten: bool) -> str: 95 | return str(features_dir / 96 | _create_pca_file_name(extractor_type, n_components, whiten)) 97 | 98 | 99 | def _create_pca_file_name( 100 | extractor_type: str, n_components: int, whiten: bool) -> str: 101 | return f"pca_transform_{extractor_type}_{n_components}_whiten_{whiten}.pkl" 102 | 103 | 104 | if __name__ == "__main__": 105 | main() 106 | -------------------------------------------------------------------------------- /spivak/models/averaging_predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023, Yahoo Inc. 2 | # Licensed under the Apache License, Version 2.0. 3 | # See the accompanying LICENSE file for terms. 4 | 5 | import pickle 6 | from pathlib import Path 7 | from typing import Union 8 | 9 | import numpy as np 10 | from scipy.special import logit, expit 11 | 12 | from spivak.data.dataset import VideoDatum 13 | from spivak.data.output_names import OUTPUT_DETECTION_SCORE 14 | from spivak.models.delta_dense_predictor import DeltaDensePredictor, clip_frames 15 | from spivak.models.dense_predictor import DensePredictor, \ 16 | create_detection_scores, OUTPUT_CONFIDENCE, OUTPUT_DELTA 17 | from spivak.models.non_maximum_suppression import FlexibleNonMaximumSuppression 18 | from spivak.models.predictor import PredictorInterface, VideoOutputs 19 | 20 | AveragingPredictor = Union[ 21 | "ConfidenceAveragingPredictor", "DeltaAveragingPredictor"] 22 | 23 | 24 | class ConfidenceAveragingPredictor(PredictorInterface): 25 | 26 | def __init__(self, weights: np.ndarray, use_logits: bool) -> None: 27 | self.weights = weights 28 | self._use_logits = use_logits 29 | 30 | def predict_video(self, video_datum: VideoDatum) -> VideoOutputs: 31 | if self._use_logits: 32 | confidence = expit(np.average( 33 | logit(video_datum.features), axis=2, weights=self.weights)) 34 | else: 35 | confidence = np.average( 36 | 
class DeltaAveragingPredictor(PredictorInterface):
    """Predictor that averages delta (displacement) outputs.

    The inputs to average are read from video_datum.features and combined
    along axis 2 using the given weights. The confidence scores are not
    computed here: they are loaded from files under confidence_dir.
    """

    def __init__(
            self, weights: np.ndarray, confidence_dir: Path,
            use_arcs: bool, delta_radius: float) -> None:
        self.weights = weights
        self._confidence_dir = confidence_dir
        self._use_arcs = use_arcs
        self._delta_radius = delta_radius

    def predict_video(self, video_datum: VideoDatum) -> VideoOutputs:
        """Average the deltas and combine them with stored confidences."""
        features = video_datum.features
        if not self._use_arcs:
            delta = np.average(features, axis=2, weights=self.weights)
        else:
            # Average in arctanh space, then map the result back so that it
            # is again scaled by the delta radius.
            radius = self._delta_radius
            averaged_arcs = np.average(
                np.arctanh(features / radius), axis=2, weights=self.weights)
            delta = radius * np.tanh(averaged_arcs)
        confidence = np.load(str(DeltaDensePredictor.confidence_path(
            self._confidence_dir, video_datum.relative_path)))
        # If delta and confidence are generated using different types of
        # features, they might have small size differences, so both are
        # clipped to the shorter of the two.
        num_frames = min(delta.shape[0], confidence.shape[0])
        video_outputs = {
            OUTPUT_CONFIDENCE: clip_frames(confidence, num_frames),
            OUTPUT_DELTA: clip_frames(delta, num_frames),
        }
        video_outputs[OUTPUT_DETECTION_SCORE] = create_detection_scores(
            video_outputs, throw_out_delta=False)
        return video_outputs

    def predict_video_and_save(
            self, video_datum: VideoDatum, nms: FlexibleNonMaximumSuppression,
            base_path: Path) -> None:
        """Run predict_video and write both predictions and labels."""
        DensePredictor.save_predictions(
            self.predict_video(video_datum), nms, base_path)
        DensePredictor.save_labels(video_datum, base_path)

    def save_model(self, model_path: str) -> None:
        """Pickle this predictor instance into model_path."""
        with open(model_path, "wb") as pickle_file:
            pickle.dump(self, pickle_file)

    def load_weights(self, weights_path: str) -> None:
        """Not supported for averaging predictors."""
        raise NotImplementedError()
def main() -> None:
    """Extract features from a folder of videos and run prediction on them.

    The arguments stored alongside the model are loaded and then overridden
    with the values given on the command line. Evaluation is disabled, so
    only predictions are produced.
    """
    cli_args = _get_command_line_arguments()
    logging.getLogger().setLevel(logging.DEBUG)
    videos_dir = Path(cli_args[Args.INPUT_VIDEOS_DIR])
    if not videos_dir.is_dir():
        raise ValueError(f"Input directory failed is_dir(): {videos_dir}")
    features_dir = Path(cli_args[Args.FEATURES_DIR])
    results_dir = Path(cli_args[Args.RESULTS_DIR])
    for output_dir in (features_dir, results_dir):
        output_dir.mkdir(parents=True, exist_ok=True)
    # Step 1: compute the features for every video in the input folder.
    extractor = create_feature_extractor(
        cli_args[Args.FEATURES], Path(cli_args[Args.FEATURES_MODELS_DIR]))
    extract_features_from_videos(
        list_video_paths(videos_dir), features_dir, extractor)
    # Step 2: set up the prediction run by loading the existing model
    # arguments and overwriting them for the current run.
    # TODO: allow this to work for the "dense_delta" model by taking two
    #  input models in MODEL_PATH and running inference twice in a row, once
    #  for each model. For now, it can be run twice with the same results
    #  folder (first with just the confidence model, then with the
    #  dense_delta one).
    model_path = cli_args[Args.MODEL_PATH]
    shared_args = SharedArgs.load(model_path)
    shared_args.model = model_path
    shared_args.results_dir = results_dir
    shared_args.features_dir = str(features_dir)
    shared_args.config_dir = cli_args[Args.CONFIG_DIR]
    shared_args.dataset_type = translate_dataset_type_to_custom(
        shared_args.dataset_type)
    shared_args.labels_dir = cli_args[Args.LABELS_DIR]
    shared_args.splits_dir = cli_args[Args.SPLITS_DIR]
    shared_args.test_split = SPLIT_KEY_UNLABELED
    # Don't run evaluation, just prediction.
    shared_args.evaluate = 0
    # Step 3: run the prediction code on the generated features.
    test(shared_args)
def create_custom_category_settings(
        categories: List[str],
        colormap_choice: ColorMapChoice) -> CategorySettings:
    """Assign one color of the chosen colormap to each category.

    When there are more categories than colors, a notice is printed and the
    color sequence is repeated as many times as needed.
    """
    palette = _colormap_from_choice(colormap_choice)
    num_categories = len(categories)
    num_colors = len(palette)
    if num_categories > num_colors:
        print(
            f"There are not enough colors in the color sequence as to support "
            f"the number of categories. Will repeat colors. Number of "
            f"categories: {len(categories)}, number of colors: "
            f"{len(plotly_colors)}".replace("{len(plotly_colors)}", "")
            if False else
            f"There are not enough colors in the color sequence as to support "
            f"the number of categories. Will repeat colors. Number of "
            f"categories: {num_categories}, number of colors: "
            f"{num_colors}")
        # Tile the palette and cut it down to one color per category.
        tiled_palette = ceil(num_categories / num_colors) * palette
        palette = tiled_palette[:num_categories]
    # zip stops at the shorter sequence, so surplus colors are ignored.
    color_from_category = dict(zip(categories, palette))
    return CategorySettings({COLUMN_CLASS: categories}, color_from_category)
def _colormap_from_choice(colormap_choice: ColorMapChoice):
    """Return the plotly qualitative color sequence for the given choice.

    Raises:
        ValueError: if colormap_choice is not one of the known choices.
    """
    # Pairs of (choice, attribute name in px.colors.qualitative). The
    # attribute is only looked up for the entry that matches.
    attribute_names = (
        (ColorMapChoice.DARK24, "Dark24"),
        (ColorMapChoice.LIGHT24, "Light24"),
        (ColorMapChoice.PLOTLY, "Plotly"),
        (ColorMapChoice.ALPHABET, "Alphabet"),
    )
    for known_choice, attribute_name in attribute_names:
        if colormap_choice == known_choice:
            return getattr(px.colors.qualitative, attribute_name)
    raise ValueError(f"Unknown colormap choice: {colormap_choice}")
class FeatureExtractorResNetTF2(RawFeatureExtractorInterface):
    """Extracts per-frame features from a video with a Keras ResNet152.

    The features are the outputs of the network's "avg_pool" layer.
    """

    def __init__(
            self, model_weights_path: str, grabber="opencv", fps=2.0,
            image_transform="crop") -> None:
        self.grabber = grabber
        self.fps = fps
        self.image_transform = image_transform
        resnet = keras.applications.resnet.ResNet152(
            include_top=True, weights=model_weights_path,
            input_tensor=None, input_shape=None, pooling=None, classes=1000)
        # Define the model so that its output is taken right after the
        # pooling layer (dim=2048).
        pooled_output = resnet.get_layer("avg_pool").output
        self.model = Model(resnet.input, outputs=[pooled_output])
        self.model.trainable = False

    def extract_features(
            self, video_path: str, game_start_time: int, game_end_time: int
    ) -> np.ndarray:
        """Decode the video's frames and run them through the network.

        A falsy game_start_time or game_end_time means that the
        corresponding limit is not passed on to the frame loader.

        Raises:
            ValueError: if self.grabber is not "skvideo" or "opencv".
        """
        start = game_start_time if game_start_time else None
        video_duration = (
            game_end_time - game_start_time if game_end_time else None)
        if self.grabber == "opencv":
            loader_class = FrameCV
        elif self.grabber == "skvideo":
            loader_class = Frame
        else:
            raise ValueError(f"Unknown frame grabber: {self.grabber}")
        video_loader = loader_class(
            video_path, FPS=self.fps, transform=self.image_transform,
            start=start, duration=video_duration)
        frames = preprocess_input(video_loader.frames)
        if video_duration is None:
            # No explicit end time: use the duration reported by the loader.
            video_duration = video_loader.time_second
        logging.info(
            f"frames {frames.shape}, fps={frames.shape[0] / video_duration}")
        # Predict the features from the frames (adjust the batch size for
        # a smaller GPU).
        tic = time.time()
        features = self.model.predict(frames, batch_size=64, verbose=1)
        prediction_time = time.time() - tic
        logging.info(f"feature model prediction time: {prediction_time}")
        logging.info(
            f"features {features.shape}, fps="
            f"{features.shape[0] / video_duration}")
        return features
class SegmentationEvaluation(TaskEvaluation):
    """Holds the mean and per-class IoU of a segmentation evaluation."""

    METRIC_MEAN_IOU = "mean_iou"
    METRIC_IOU = "iou"

    def __init__(
            self, mean_iou: float, per_class_iou: np.ndarray,
            label_map: LabelMap) -> None:
        self.mean_iou = mean_iou
        self.per_class_iou = per_class_iou
        self.label_map = label_map

    def scalars_for_logging(self) -> Dict[str, float]:
        """Return one "iou_<class name>" entry per class plus the mean."""
        prefix = SegmentationEvaluation.METRIC_IOU
        scalars = {}
        for class_index, class_iou in enumerate(self.per_class_iou):
            class_name = self.label_map.int_to_label[class_index]
            scalars[f"{prefix}_{class_name}"] = class_iou
        scalars[SegmentationEvaluation.METRIC_MEAN_IOU] = self.mean_iou
        return scalars

    def summary(self) -> str:
        """Return the human-readable report as a single string."""
        buffer = StringIO()
        try:
            self._write_summary(buffer)
            return buffer.getvalue()
        finally:
            buffer.close()

    def _write_summary(self, str_io: StringIO) -> None:
        """Write the mean IoU followed by one line per class to str_io."""
        mean_name = SegmentationEvaluation.METRIC_MEAN_IOU
        str_io.write("Segmentation evaluation:\n")
        str_io.write(f"{mean_name}: {self.mean_iou}\n")
        str_io.write("\nIoU per class:\n")
        int_to_label = self.label_map.int_to_label
        for class_index, class_iou in enumerate(self.per_class_iou):
            str_io.write(f"{int_to_label[class_index]}: {class_iou}\n")
def _run_segmentation_evaluation(all_predictions, all_labels, num_classes):
    """Compute the F1 scores and the IoU metrics over a set of videos.

    Intersection and union frame counts are accumulated per class over all
    the videos, and the per-class IoU is their ratio.

    A class that appears in neither the predictions nor the targets has an
    empty union, so its IoU is undefined. Such classes are reported as NaN
    in per_class_iou and are left out of mean_iou. Previously, they caused
    a 0/0 division warning and turned mean_iou itself into NaN.

    Args:
        all_predictions: per-video arrays whose argmax over axis 1 gives
            the predicted class of each frame.
        all_labels: per-video labels, converted to targets with
            segmentation_targets_from_change_labels.
        num_classes: number of segmentation classes.

    Returns:
        Tuple (f1_macro, f1_micro, f1_manual, mean_iou, per_class_iou).
        mean_iou is NaN only when no class has a defined IoU.
    """
    intersection_counts_per_class = np.zeros(num_classes, dtype=np.float32)
    union_counts_per_class = np.zeros(num_classes, dtype=np.float32)
    all_targets = [
        segmentation_targets_from_change_labels(labels)
        for labels in all_labels]
    for video_predictions, video_targets in zip(all_predictions, all_targets):
        # Convert from one-hot to integers.
        video_predictions_integers = video_predictions.argmax(axis=1)
        video_targets_integers = video_targets.argmax(axis=1)
        for class_index in range(num_classes):
            target_mask = (video_targets_integers == class_index)
            prediction_mask = (video_predictions_integers == class_index)
            intersection_counts_per_class[class_index] += np.sum(
                np.logical_and(target_mask, prediction_mask), dtype=np.float32)
            union_counts_per_class[class_index] += np.sum(
                np.logical_or(target_mask, prediction_mask), dtype=np.float32)
    # Only divide where the union is non-empty; the remaining entries keep
    # their NaN initialization.
    defined = union_counts_per_class > 0
    per_class_iou = np.full(num_classes, np.nan, dtype=np.float32)
    np.divide(
        intersection_counts_per_class, union_counts_per_class,
        out=per_class_iou, where=defined)
    if np.any(defined):
        mean_iou = float(np.mean(per_class_iou[defined]))
    else:
        mean_iou = float("nan")
    f1_macro, f1_micro, f1_manual = calculate_f1_scores(
        all_targets, all_predictions, num_classes)
    return f1_macro, f1_micro, f1_manual, mean_iou, per_class_iou
class Dataset:
    """Groups a list of videos with the shape and task information that
    describes them."""

    def __init__(
            self, video_data: List["VideoDatum"], input_shape: InputShape,
            num_classes_from_task: Dict[Task, int]) -> None:
        # The videos and their count.
        self.video_data = video_data
        self.num_videos = len(video_data)
        # The number of features is the second entry of the input shape.
        self.input_shape = input_shape
        self.num_features = input_shape[1]
        # The tasks are the keys of the classes-per-task dictionary.
        self.num_classes_from_task = num_classes_from_task
        self.tasks = list(num_classes_from_task.keys())
def read_numpy_shape(features_path: Path) -> Tuple[int, ...]:
    """Read an array's shape from a .npy file without loading its data.

    Only the file header is parsed. NPY format versions 1.0 and 2.0 are
    supported (2.0 is what NumPy writes when the header is too large for
    version 1.0).

    Args:
        features_path: path to a file written with np.save.

    Returns:
        The shape of the stored array, as a tuple of ints.

    Raises:
        ValueError: if the file uses an unsupported NPY format version.
    """
    with features_path.open('rb') as features_file:
        file_version = np.lib.format.read_magic(features_file)
        if file_version == (1, 0):
            shape, _, _ = np.lib.format.read_array_header_1_0(features_file)
        elif file_version == (2, 0):
            shape, _, _ = np.lib.format.read_array_header_2_0(features_file)
        else:
            # An explicit error instead of an assert, so that the check is
            # not stripped when Python runs with -O.
            raise ValueError(
                f"Unsupported NPY format version {file_version} in "
                f"{features_path}")
    return shape
| valid_chunk_sizes.append(self.valid_chunk_size) 48 | return np.concatenate(input_chunk_batch_list), valid_chunk_sizes 49 | 50 | def accumulate_chunk_outputs( 51 | self, accumulated_output: np.ndarray, 52 | output_chunks: List[np.ndarray]) -> None: 53 | chunk_index = 0 54 | while self.has_next(): 55 | self.next() 56 | self.accumulate(accumulated_output, output_chunks[chunk_index]) 57 | chunk_index += 1 58 | 59 | def has_next(self) -> bool: 60 | return not self._is_last 61 | 62 | def next(self) -> None: 63 | # Get the outputs for this chunk and store the results. 64 | self.valid_chunk_size = min( 65 | self._chunk_frames, 66 | self._video_features_expanded.shape[0] - self._chunk_start) 67 | chunk_end = self._chunk_start + self.valid_chunk_size 68 | if self.valid_chunk_size == self._chunk_frames: 69 | # This might be faster than creating the np.zeros as below, 70 | # but not sure. 71 | chunk_features_expanded = \ 72 | self._video_features_expanded[self._chunk_start:chunk_end] 73 | else: 74 | # in this case, data_expanded is not as big as chunk_size, so we add 75 | # extra zero-padding before passing it through the network. 76 | chunk_features_expanded = np.zeros(( 77 | self._chunk_frames, self._video_features_expanded.shape[1], 78 | self._video_features_expanded.shape[2])) 79 | chunk_features_expanded[0:self.valid_chunk_size] = \ 80 | self._video_features_expanded[self._chunk_start:chunk_end] 81 | # Prepare the batch made of one chunk for the network 82 | self.chunk_features_expanded = np.expand_dims( 83 | chunk_features_expanded, axis=0) 84 | # Figure out start and end indexes. 
85 | is_first = (self._chunk_start == 0) 86 | if is_first: 87 | self._output_start = 0 88 | else: 89 | self._output_start = self._num_border_frames 90 | self._result_start = self._chunk_start + self._output_start 91 | self._is_last = ( 92 | self._chunk_start >= self._num_frames - self._chunk_frames) 93 | if self._is_last: 94 | self._output_end = self.valid_chunk_size 95 | else: 96 | self._output_end = (self.valid_chunk_size - 97 | self._num_border_frames) 98 | self._result_end = self._chunk_start + self._output_end 99 | # Update the start index for the next iteration 100 | self._chunk_start += self._chunk_frames - 2 * self._num_border_frames 101 | if self._chunk_start > self._num_frames - self._chunk_frames: 102 | self._chunk_start = self._num_frames - self._chunk_frames 103 | 104 | def accumulate( 105 | self, accumulated_output: np.ndarray, 106 | output_chunk: np.ndarray) -> None: 107 | result_start = self._result_start 108 | result_end = self._result_end 109 | output_start = self._output_start 110 | output_end = self._output_end 111 | accumulated_output[result_start:result_end] = \ 112 | output_chunk[output_start:output_end] 113 | -------------------------------------------------------------------------------- /spivak/application/command_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023, Yahoo Inc. 2 | # Licensed under the Apache License, Version 2.0. 3 | # See the accompanying LICENSE file for terms. 
class Command:
    """A described command line that can be printed or executed.

    Arguments are given either as a list of strings, used as is, or as a
    dictionary from flag to value. With a dictionary, each value is split
    on single spaces, so that one flag can be followed by several values.
    """

    def __init__(
            # "Args" is quoted so that the module-level alias is resolved
            # lazily by type checkers.
            self, description: str, executable: str, arguments: "Args",
            cwd: Optional[str] = None,
            env_vars: Optional[Dict[str, str]] = None) -> None:
        self.description = description
        self.executable = executable
        self.arguments = arguments
        self.cwd = cwd
        # Make sure self.env_vars is a dictionary, to simplify the code later.
        self.env_vars = {} if env_vars is None else env_vars

    def run(self) -> subprocess.CompletedProcess:
        """Print the command and run it, returning the completed process.

        The command's environment is the current one, updated with
        env_vars. The return code is not checked here.
        """
        print(f"Going to run the following command: {self.description}")
        print(self.command_line_str())
        env = {**os.environ, **self.env_vars}
        return subprocess.run(self._as_list(), cwd=self.cwd, env=env)

    def command_line_str(self) -> str:
        """Return the command line, prefixed by the environment variables.

        No leading space is produced when there are no environment
        variables.
        """
        command_str = " ".join(self._as_list())
        variables_str = self.environment_variables_str()
        if not variables_str:
            return command_str
        return f"{variables_str} {command_str}"

    def environment_variables_str(self) -> str:
        """Return the extra environment variables as "KEY=value" pairs."""
        return " ".join(
            f"{key}={value}" for key, value in self.env_vars.items())

    def __str__(self) -> str:
        complete_str = f"Command: {self.description}\n{self.command_line_str()}"
        if self.cwd:
            complete_str = complete_str + f"\nCWD: {self.cwd}"
        return complete_str

    def _as_list(self) -> List[str]:
        """Return the executable followed by its flattened arguments.

        Raises:
            ValueError: if self.arguments is neither a dict nor a list.
        """
        if isinstance(self.arguments, dict):
            arguments_list = []
            for key, value in self.arguments.items():
                arguments_list.append(key)
                # A value without spaces is split into just itself.
                arguments_list.extend(value.split(" "))
        elif isinstance(self.arguments, list):
            arguments_list = self.arguments
        else:
            raise ValueError(
                f"Command instance arguments is of unexpected type: "
                f"{type(self.arguments)}")
        return [self.executable] + arguments_list
_parameters_to_string(parameters) 111 | return (f"{protocol}_{feature_name}_{model_name}_{extra_name}" 112 | f"_{parameters_string}") 113 | 114 | 115 | def print_command_list(commands: List[Command]) -> None: 116 | for command in commands: 117 | print(command, end="\n\n") 118 | 119 | 120 | def _parameters_to_string(parameters: DictArgs) -> str: 121 | parameters_strings = [ 122 | _parameter_to_string(key, parameters[key]) for key in parameters] 123 | return "_".join(parameters_strings) 124 | 125 | 126 | def _parameter_to_string(key: str, value: str) -> str: 127 | clean_key = key.replace("-", "") 128 | if _has_digit(value): 129 | separator = "" 130 | else: 131 | separator = "_" 132 | return f"{clean_key}{separator}{value}" 133 | 134 | 135 | def _has_digit(value: str) -> bool: 136 | return any(char.isdigit() for char in value) 137 | -------------------------------------------------------------------------------- /spivak/models/assembly/bottom_up.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023, Yahoo Inc. 2 | # Licensed under the Apache License, Version 2.0. 3 | # See the accompanying LICENSE file for terms. 
from abc import ABCMeta, abstractmethod
from typing import List

from tensorflow import Tensor
from tensorflow.keras.layers import MaxPooling2D, AveragePooling2D

from spivak.models.assembly.convolution_stacks import StridedBlockInterface, \
    ConvolutionStackInterface

POOLING_MAX = "max"
POOLING_AVERAGE = "average"


class BottomUpStackInterface(metaclass=ABCMeta):
    """Interface for the stacks used to build each bottom-up layer."""

    @abstractmethod
    def downsample_and_convolve(
            self, bottom_up: Tensor, num_filters_in: int, num_filters_out: int,
            layer_index: int) -> Tensor:
        """Downsample the input tensor and then convolve it."""
        pass

    @abstractmethod
    def convolve(
            self, bottom_up: Tensor, num_filters: int,
            layer_index: int) -> Tensor:
        """Convolve the input tensor without downsampling it."""
        pass


class BottomUpLayer:
    """Pairs a layer's output tensor with its number of channels."""

    def __init__(self, tensor: Tensor, num_channels: int) -> None:
        self.tensor = tensor
        self.num_channels = num_channels


class PoolingBottomUpStack(BottomUpStackInterface):
    """Bottom-up stack that downsamples with a (2, 1) pooling layer."""

    def __init__(
            self, pooling: str,
            convolution_stack: ConvolutionStackInterface) -> None:
        self._pooling = pooling
        self._convolution_stack = convolution_stack
        self._name = "bu"

    def convolve(
            self, bottom_up: Tensor, num_filters: int,
            layer_index: int) -> Tensor:
        """Delegate to the convolution stack, tagging layers with the name."""
        return self._convolution_stack.convolve(
            bottom_up, num_filters, layer_index, self._name)

    def downsample_and_convolve(
            self, bottom_up: Tensor, num_filters_in: int, num_filters_out: int,
            layer_index: int) -> Tensor:
        """Pool the input, then run the convolution stack on the result.

        num_filters_in is not used by this implementation.

        Raises:
            ValueError: If the configured pooling type is not recognized.
        """
        # Maps each pooling type to its Keras layer class and name suffix.
        pooling_choices = {
            POOLING_MAX: (MaxPooling2D, "max_pooling"),
            POOLING_AVERAGE: (AveragePooling2D, "average_pooling"),
        }
        if self._pooling not in pooling_choices:
            raise ValueError(f"Unrecognized pooling: {self._pooling}")
        layer_class, suffix = pooling_choices[self._pooling]
        pooling_layer = layer_class(
            (2, 1), name=f"{self._name}_layer{layer_index}_{suffix}")
        return self._convolution_stack.convolve(
            pooling_layer(bottom_up), num_filters_out, layer_index, self._name)


class StridedBottomUpStack(BottomUpStackInterface):
    """Bottom-up stack that downsamples with a strided convolution block."""

    def __init__(
            self, strided_block: StridedBlockInterface,
            layer_num_blocks: List[int], strided_reduction: bool) -> None:
        self._strided_block = strided_block
        self._layer_num_blocks = layer_num_blocks
        self._strided_reduction = strided_reduction
        self._name = "bu"

    def convolve(
            self, bottom_up: Tensor, num_filters: int,
            layer_index: int) -> Tensor:
        """Chain as many blocks as configured for the given layer."""
        num_blocks = self._layer_num_blocks[layer_index]
        output = bottom_up
        for block_index in range(num_blocks):
            block_name = f"{self._name}_layer{layer_index}_block{block_index}"
            output = self._strided_block.convolve(
                output, num_filters, name=block_name)
        return output

    def downsample_and_convolve(
            self, bottom_up: Tensor, num_filters_in: int, num_filters_out: int,
            layer_index: int) -> Tensor:
        """Apply one strided block, then the layer's remaining blocks.

        The strided block keeps the input number of filters only when
        strided reduction is off and the layer has at least two blocks.
        Otherwise it already produces num_filters_out.
        """
        num_blocks = self._layer_num_blocks[layer_index]
        keep_input_filters = (
            not self._strided_reduction and num_blocks >= 2)
        stride_filters = (
            num_filters_in if keep_input_filters else num_filters_out)
        output = self._strided_block.strided_convolve(
            bottom_up, stride_filters,
            name=f"{self._name}_layer{layer_index}_block0")
        for block_index in range(1, num_blocks):
            block_name = f"{self._name}_layer{layer_index}_block{block_index}"
            output = self._strided_block.convolve(
                output, num_filters_out, name=block_name)
        return output


def create_bottom_up_layers(
        input_mlp_out: Tensor, num_layers: int, base_num_filters: int,
        max_num_filters: int, bottom_up_stack: BottomUpStackInterface
) -> List[BottomUpLayer]:
    """Create the list of bottom-up layers, from first to last.

    The first layer only convolves. Every later layer downsamples and then
    convolves. The number of filters doubles with each layer, capped at
    max_num_filters.
    """
    # VGG-16 applied dropout of 0.5 to their last two layers. U-net paper did
    # something similar, only applying dropout at the end of their bottom-up
    # layers.
    # https://arxiv.org/pdf/1409.1556.pdf
    # https://arxiv.org/pdf/1505.04597.pdf
    # The first layer is a plain convolution stack.
    previous_filters = min(max_num_filters, base_num_filters)
    tensor = bottom_up_stack.convolve(input_mlp_out, previous_filters, 0)
    layers = [BottomUpLayer(tensor, previous_filters)]
    # Each remaining layer downsamples and then convolves.
    for layer_index in range(1, num_layers):
        current_filters = min(
            max_num_filters, 2 ** layer_index * base_num_filters)
        tensor = bottom_up_stack.downsample_and_convolve(
            tensor, previous_filters, current_filters, layer_index)
        layers.append(BottomUpLayer(tensor, current_filters))
        previous_filters = current_filters
    return layers


# --------------------------------------------------------------------------------
# /bin/create_averaging_predictor.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python3

# Copyright 2023, Yahoo Inc.
# Licensed under the Apache License, Version 2.0.
# See the accompanying LICENSE file for terms.
import logging
from pathlib import Path
from typing import List

import numpy as np

from spivak.application.argument_parser import get_args, \
    DETECTOR_AVERAGING_CONFIDENCE, DETECTOR_AVERAGING_DELTA, dir_str_to_path, \
    SharedArgs
from spivak.application.dataset_creation import create_label_maps, \
    create_soccernet_video_data_reader
from spivak.application.model_creation import create_flexible_nms, \
    create_delta_radius
from spivak.application.validation import create_all_video_outputs, \
    create_detections_and_targets
from spivak.data.dataset import Task, VideoDatum, read_numpy_shape
from spivak.data.dataset_splits import SPLIT_KEY_VALIDATION
from spivak.data.label_map import LabelMap
from spivak.evaluation.spotting_evaluation import run_spotting_evaluation, \
    TolerancesConfig, read_tolerances_config
from spivak.models.averaging_predictor import ConfidenceAveragingPredictor, \
    DeltaAveragingPredictor, AveragingPredictor
from spivak.models.non_maximum_suppression import FlexibleNonMaximumSuppression

DELTA_WEIGHT_RANGE = [0.0, 0.3, 0.5, 0.7, 0.8, 0.9, 0.95, 0.97, 0.98, 0.99, 1.0]
CONFIDENCE_LOGIT_WEIGHT_RANGE = [
    0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6,
    0.65, 0.675, 0.7, 0.725, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0]
CONFIDENCE_USE_LOGITS = True
DELTA_USE_ARCS = False


def main() -> None:
    """Search for the best averaging weight on validation and save the model.

    Raises:
        ValueError: If the validation split contains no videos.
    """
    logging.getLogger().setLevel(logging.DEBUG)
    args = get_args()
    label_maps = create_label_maps(args)
    soccernet_video_data_reader = create_soccernet_video_data_reader(
        args, label_maps)
    video_data = soccernet_video_data_reader.read(SPLIT_KEY_VALIDATION)
    if len(video_data) == 0:
        # Fail with a clear message instead of an IndexError below.
        raise ValueError(
            f"No videos found in the {SPLIT_KEY_VALIDATION} split, so the "
            f"averaging weights cannot be optimized.")
    # The number of features is read from the first video's features file.
    num_features = _read_num_features(video_data[0]._features_path)
    predictor = _create_averaging_predictor(args, num_features)
    _optimize_predictor(args, predictor, video_data, label_maps[Task.SPOTTING])
    predictor.save_model(args.model)


def _read_num_features(features_path: Path) -> int:
    """Return the size of the third dimension of the stored features."""
    shape = read_numpy_shape(features_path)
    return shape[2]


def _create_averaging_predictor(
        args: SharedArgs, num_features: int) -> AveragingPredictor:
    """Create the averaging predictor selected by args.detector.

    The predictor starts with all-zero weights of length num_features.

    Raises:
        ValueError: If args.detector is not an averaging detector.
    """
    weights = np.zeros(num_features)
    if args.detector == DETECTOR_AVERAGING_CONFIDENCE:
        predictor = ConfidenceAveragingPredictor(weights, CONFIDENCE_USE_LOGITS)
    elif args.detector == DETECTOR_AVERAGING_DELTA:
        predictor = DeltaAveragingPredictor(
            weights, dir_str_to_path(args.results_dir), DELTA_USE_ARCS,
            create_delta_radius(args))
    else:
        raise ValueError(f"Unknown averaging predictor type: {args.detector}")
    return predictor


def _optimize_predictor(
        args: SharedArgs, predictor: AveragingPredictor,
        video_data: List[VideoDatum], label_map: LabelMap) -> None:
    """Set predictor.weights to the candidate with the best main metric.

    Every weight in the detector's weight range is evaluated on video_data.
    On ties, the first weight reaching the maximum is kept (np.argmax).
    """
    weight_range = _create_weight_range(args.detector)
    flexible_nms = create_flexible_nms(args, label_map)
    # Read the tolerances used by the spotting evaluation.
    config_dir = dir_str_to_path(args.config_dir)
    tolerances_config = read_tolerances_config(config_dir)
    logging.info(f"Computing metrics for weights in {weight_range}")
    metrics = [
        _compute_main_metric(
            args, video_data, predictor, weight, flexible_nms,
            tolerances_config, label_map)
        for weight in weight_range
    ]
    logging.info(f"Metrics found: {metrics}")
    max_metric_index = int(np.argmax(metrics))
    best_weight = weight_range[max_metric_index]
    predictor.weights = _weights_from_weight(best_weight)


def _compute_main_metric(
        args: SharedArgs, video_data: List[VideoDatum],
        predictor: AveragingPredictor, weight: float,
        flexible_nms: FlexibleNonMaximumSuppression,
        tolerances_config: TolerancesConfig, label_map: LabelMap) -> float:
    """Evaluate the predictor with the given weight and return its metric.

    Note that this overwrites predictor.weights as a side effect. The value
    returned is the average-mAP stored under the evaluation's main tolerances
    name.
    """
    predictor.weights = _weights_from_weight(weight)
    all_video_outputs = create_all_video_outputs(video_data, predictor)
    detections, targets = create_detections_and_targets(
        video_data, all_video_outputs, flexible_nms)
    spotting_evaluation = run_spotting_evaluation(
        detections, targets, tolerances_config, args.frame_rate,
        label_map.num_classes(), bool(args.prune_classes),
        create_confusion_data_frame=False, label_map=label_map)
    main_metric = spotting_evaluation.average_map_dict[
        spotting_evaluation.main_tolerances_name]
    logging.info(
        f"Got {spotting_evaluation.main_tolerances_name} average-mAP: "
        f"{main_metric} for weight {weight}.")
    return main_metric


def _create_weight_range(detector: str) -> List[float]:
    """Return the candidate weights to try for the given detector.

    Raises:
        ValueError: If the detector is not an averaging detector, or if
            confidence averaging is requested while CONFIDENCE_USE_LOGITS is
            off (no weight range is defined for that case).
    """
    if detector == DETECTOR_AVERAGING_CONFIDENCE:
        if not CONFIDENCE_USE_LOGITS:
            # Report the actual cause rather than an "unknown type" error.
            raise ValueError(
                "No weight range is defined for confidence averaging when "
                "CONFIDENCE_USE_LOGITS is False")
        return CONFIDENCE_LOGIT_WEIGHT_RANGE
    if detector == DETECTOR_AVERAGING_DELTA:
        return DELTA_WEIGHT_RANGE
    raise ValueError(f"Unknown averaging predictor type: {detector}")


def _weights_from_weight(weight: float) -> np.ndarray:
    """Return the two-element weights array [weight, 1 - weight]."""
    return np.array([weight, 1.0 - weight])


if __name__ == "__main__":
    main()


# --------------------------------------------------------------------------------
# /spivak/video_visualization/recognition_visualization.py:
# --------------------------------------------------------------------------------
# Copyright 2023, Yahoo Inc.
# Licensed under the Apache License, Version 2.0.
# See the accompanying LICENSE file for terms.
import math
from typing import List, Optional
from typing import Tuple

import cv2
import numpy as np

from spivak.data.label_map import LabelMap

LABEL_SEPARATOR = ";"
LABEL_TEXT_RELATIVE_GAP_X = 30
BOTTOM_LABEL_TEXT_RELATIVE_GAP_Y = 20
TARGET_FONT_SCALE = 1.0
SMALL_TARGET_FONT_SCALE = 0.6
BOLD_FONT_RELATIVE_THICKNESS = 2.0
# Avoid changing the font, as it will change other constants below
# (unfortunately, the font size changes with the font choice).
FONT = cv2.FONT_HERSHEY_SIMPLEX
# Don't make the font too small, so the number of pixels in the text isn't
# so small that it becomes unreadable.
FONT_SCALE_MIN_ABSOLUTE = 0.3
# Don't make the font too small, relative to the image size.
FONT_SCALE_MIN_RELATIVE = 5e-4
# Don't make the font too large, relative to the image size.
FONT_SCALE_MAX_RELATIVE = 1e-3
CV_COLOR_TEXT = (255, 255, 255)

# What cv2.getTextSize returns for one text: ((width, height), baseline).
OpenCVTextSize = Tuple[Tuple[int, int], int]
OpenCVTextSizes = List[Tuple[Tuple[int, int], int]]


class Label:
    """Text to draw on a frame, and whether to draw it in bold."""

    def __init__(self, text: str, bold: bool) -> None:
        self.text = text
        self.bold = bold


class FrameRecognizedActionsView:
    """Holds the scores for one frame, its time, and the label map."""

    def __init__(self, time_in_seconds: float, scores: np.ndarray,
                 label_map: LabelMap) -> None:
        self.time_in_seconds = time_in_seconds
        self.scores = scores
        self.label_map = label_map


def get_label(
        score: float, recognition_threshold: float, class_name: str) -> Label:
    """Create a label that is bold when the score exceeds the threshold."""
    return Label(create_label_text(score, class_name),
                 score > recognition_threshold)


def get_multi_labels(
        frames_actions_view: FrameRecognizedActionsView,
        recognition_threshold: float) -> List[Label]:
    """Create one label per score, named via the view's label map."""
    return [get_label(
        score, recognition_threshold,
        frames_actions_view.label_map.int_to_label[index])
        for index, score in enumerate(frames_actions_view.scores)]


def get_standard_font_thickness(font_scale: float) -> int:
    """Return the truncated font scale, but at least 1."""
    return max(1, int(font_scale))


def get_bold_font_thickness(
        standard_font_thickness: int, font_scale: float) -> int:
    """Return a thickness for bold text, always above the standard one.

    The scale-based thickness truncates to 0 or 1 for small font scales,
    which would make bold text no thicker (or even thinner) than standard
    text. The result is therefore at least standard_font_thickness + 1.
    """
    bold_font_thickness = int(font_scale * BOLD_FONT_RELATIVE_THICKNESS)
    return max(bold_font_thickness, standard_font_thickness + 1)


def get_text_size(
        label: Label, font_scale: float, standard_font_thickness: int,
        bold_font_thickness: int) -> OpenCVTextSize:
    """Return cv2.getTextSize for the label, using its effective thickness."""
    if label.bold:
        effective_thickness = bold_font_thickness
    else:
        effective_thickness = standard_font_thickness
    return cv2.getTextSize(label.text, FONT, font_scale, effective_thickness)


def cv2_draw_labels(
        np_image: np.ndarray, frame_time: float,
        labels: Optional[List[Label]]) -> None:
    """Draw the labels side by side near the bottom of the image, in place.

    Nothing is drawn when labels is None or empty. frame_time is not used
    here.
    """
    if not labels:
        return
    image_height, image_width, _ = np_image.shape
    font_scale = compute_font_scale(image_height, image_width)
    standard_font_thickness = get_standard_font_thickness(font_scale)
    bold_font_thickness = get_bold_font_thickness(
        standard_font_thickness, font_scale)
    text_sizes = [
        get_text_size(label, font_scale, standard_font_thickness,
                      bold_font_thickness) for label in labels]
    baselines = np.array([t[1] for t in text_sizes])
    max_baseline = max(baselines)
    # All labels share the same vertical position, leaving room for the
    # largest baseline plus a gap that scales with the font.
    location_y = math.ceil(
        image_height - max_baseline - font_scale *
        BOTTOM_LABEL_TEXT_RELATIVE_GAP_Y)
    location_xs = compute_labels_xs(font_scale, text_sizes)
    for i, (label, location_x) in enumerate(zip(labels, location_xs)):
        if label.bold:
            font_thickness = bold_font_thickness
        else:
            font_thickness = standard_font_thickness
        text = label.text
        # Add the separator to all labels except the last one.
        if i != len(labels) - 1:
            text += LABEL_SEPARATOR
        cv2.putText(
            np_image, text, (location_x, location_y), FONT, font_scale,
            CV_COLOR_TEXT, font_thickness)


def compute_labels_xs(
        font_scale: float, text_sizes: OpenCVTextSizes) -> List[int]:
    """Return the horizontal start position for each label.

    Each label starts after the widths of all previous labels, plus one
    font-scaled gap per label (the first label is also preceded by a gap).
    """
    label_widths = np.array([t[0][0] for t in text_sizes])
    relative_shifts = np.insert(label_widths[:-1], 0, 0)
    relative_shifts_with_gaps = (
        relative_shifts + font_scale * LABEL_TEXT_RELATIVE_GAP_X)
    label_shifts = np.cumsum(relative_shifts_with_gaps)
    return [int(np.rint(shift)) for shift in label_shifts]


def create_label_text(score: float, class_name: str) -> str:
    """Format the score with two decimals, followed by the class name."""
    return "[{:.2f}]: {:s}".format(score, class_name)


def compute_font_scale(image_height: int, image_width: int) -> float:
    """Choose the font scale from the image dimensions."""
    if image_height > image_width:
        # We have a tall video, so let's make the font smaller so that more
        # text will fit in the bottom.
        target_font_scale = SMALL_TARGET_FONT_SCALE
    else:
        target_font_scale = TARGET_FONT_SCALE
    return _limit_font_scale(target_font_scale, image_height)


def _limit_font_scale(target_font_scale: float, image_height: int) -> float:
    """Clamp the target font scale to limits derived from the image height.

    The upper limit is applied last, so it takes precedence when it is
    smaller than the lower limit.
    """
    min_font_scale = max(
        FONT_SCALE_MIN_RELATIVE * image_height, FONT_SCALE_MIN_ABSOLUTE)
    max_font_scale = FONT_SCALE_MAX_RELATIVE * image_height
    return min(max(min_font_scale, target_font_scale), max_font_scale)


# --------------------------------------------------------------------------------
# /spivak/feature_extraction/extraction.py:
# --------------------------------------------------------------------------------
# Copyright 2023, Yahoo Inc.
# Licensed under the Apache License, Version 2.0.
# See the accompanying LICENSE file for terms.
import logging
import pickle
from pathlib import Path
from typing import List

import numpy as np
from sklearn.decomposition import IncrementalPCA

from spivak.feature_extraction.soccernet_v2 import FeatureExtractorResNetTF2, \
    SoccerNetPCATransformer, PCAInterface, RawFeatureExtractorInterface

# Feature extraction constants.
DEFAULT_TIME_STRIDE = 0.5
GAME_START_TIME = 0
GAME_END_TIME = 0
EXTRACTOR_TYPE_RESNET_TF2 = "ResNet_TF2"
# The SOCCERNET_FEATURE_NAME_* constants have to match the convention in the
# SoccerNet codebase.
SOCCERNET_FEATURE_NAME_RESNET_TF2 = "ResNET_TF2"
MODEL_WEIGHTS_RESNET_TF2 = "resnet152_weights_tf_dim_ordering_tf_kernels.h5"
PCA_RESNET_TF2 = "pca_512_TF2.pkl"
PCA_SCALER_RESNET_TF2 = "average_512_TF2.pkl"


class VideoInfo:
    """Paths related to one video: the video itself and its features dirs."""

    def __init__(
            self, video_path: Path, features_dir: Path,
            relative_features_dir: Path) -> None:
        self.video_path = video_path
        self.features_dir = features_dir
        self.relative_features_dir = relative_features_dir


class PCATransformer(PCAInterface):
    """PCA transformer backed by a pickled projection object."""

    def __init__(self, file_path: Path) -> None:
        """Load the projection from the pickle file at file_path."""
        with file_path.open("rb") as pickle_file:
            # NOTE(review): pickle.load can execute arbitrary code, so
            # file_path must only ever point to a trusted file.
            self.projection: IncrementalPCA = pickle.load(pickle_file)

    def transform(self, raw_features: np.ndarray) -> np.ndarray:
        """Apply the loaded projection to the raw features."""
        return self.projection.transform(raw_features)


class FeatureExtractor:
    """Creates PCA feature files for videos, caching the raw features."""

    def __init__(
            self, raw_feature_extractor: RawFeatureExtractorInterface,
            feature_file_prefix: str, pca_transformer: PCAInterface) -> None:
        self.raw_feature_extractor = raw_feature_extractor
        self.feature_file_prefix = feature_file_prefix
        self.pca_transformer = pca_transformer

    def make_features(self, video_info: VideoInfo) -> None:
        """Create the PCA features file for a video, unless it exists."""
        features_file_name = f"{self.feature_file_prefix}_PCA512.npy"
        features_path = video_info.features_dir / features_file_name
        if features_path.exists():
            logging.warning(
                f"*** Skipping feature creation (already exists) for: "
                f"{features_path}")
            return
        logging.info(f"Creating feature file: {features_path}")
        features_array = self._create_features(video_info)
        np.save(str(features_path), features_array)

    def _create_features(self, video_info: VideoInfo) -> np.ndarray:
        """Return the PCA-transformed raw features of the video."""
        raw_features = self._get_raw_features(video_info)
        return self.pca_transformer.transform(raw_features)

    def _get_raw_features(self, video_info: VideoInfo) -> np.ndarray:
        """Load the raw features from disk, or extract and save them."""
        raw_file_name = f"{self.feature_file_prefix}.npy"
        raw_path = video_info.features_dir / raw_file_name
        if raw_path.exists():
            logging.warning(
                f"*** Skipping RAW features creation (already exists) for: "
                f"{raw_path}")
            return np.load(str(raw_path))
        logging.info(f"Creating RAW feature file: {raw_path}")
        raw_array = self.raw_feature_extractor.extract_features(
            str(video_info.video_path), GAME_START_TIME, GAME_END_TIME)
        np_raw_features = np.array(raw_array)
        np.save(str(raw_path), np_raw_features)
        return np_raw_features


def extract_features_from_videos(
        video_paths: List[Path], features_dir: Path,
        feature_extractor: FeatureExtractor) -> None:
    """Create feature files for each video under features_dir.

    Each video gets its own sub-directory, named after the video file's stem.
    """
    video_infos = _create_video_infos(video_paths, features_dir)
    _make_feature_directories(video_infos)
    _make_features(video_infos, feature_extractor)


def create_feature_extractor(
        extractor_type: str, model_dir: Path) -> FeatureExtractor:
    """Create the feature extractor, reading model files from model_dir.

    Raises:
        ValueError: If extractor_type is not a known extractor type.
    """
    if extractor_type == EXTRACTOR_TYPE_RESNET_TF2:
        fps = 1.0 / DEFAULT_TIME_STRIDE
        resnet_tf2_path = str(model_dir / MODEL_WEIGHTS_RESNET_TF2)
        soccernet_feature_extractor = FeatureExtractorResNetTF2(
            resnet_tf2_path, fps=fps)
        pca_path = model_dir / PCA_RESNET_TF2
        pca_scaler_path = model_dir / PCA_SCALER_RESNET_TF2
        pca_transformer = SoccerNetPCATransformer(pca_path, pca_scaler_path)
    else:
        raise ValueError(f"Invalid value for extractor type: {extractor_type}")
    feature_file_prefix = extractor_type_to_feature_name(extractor_type)
    return FeatureExtractor(
        soccernet_feature_extractor, feature_file_prefix, pca_transformer)


def extractor_type_to_feature_name(extractor_type: str) -> str:
    """Map an extractor type to the feature name used in file names.

    Raises:
        ValueError: If extractor_type is not a known extractor type.
    """
    if extractor_type == EXTRACTOR_TYPE_RESNET_TF2:
        soccernet_feature_name = SOCCERNET_FEATURE_NAME_RESNET_TF2
    else:
        raise ValueError(f"Invalid value for extractor type: {extractor_type}")
    return soccernet_feature_name


def _create_video_infos(
        video_paths: List[Path], features_dir: Path) -> List[VideoInfo]:
    """Create one VideoInfo per video path."""
    return [_create_video_info(video_path, features_dir) for video_path in
            video_paths]


def _create_video_info(video_path: Path, features_dir: Path) -> VideoInfo:
    """Create the VideoInfo whose features dir is named after the video."""
    video_features_dir = features_dir / video_path.stem
    relative_features_dir = video_features_dir.relative_to(features_dir)
    return VideoInfo(video_path, video_features_dir, relative_features_dir)


def _make_feature_directories(video_infos: List[VideoInfo]) -> None:
    """Create the features directory of every video, if missing."""
    for video_info in video_infos:
        # parents=True so that a features root that does not exist yet is
        # created as well, instead of raising FileNotFoundError.
        video_info.features_dir.mkdir(parents=True, exist_ok=True)


def _make_features(
        video_infos: List[VideoInfo],
        feature_extractor: FeatureExtractor) -> None:
    """Run the feature extractor on every video."""
    for video_info in video_infos:
        feature_extractor.make_features(video_info)


# --------------------------------------------------------------------------------
# /spivak/models/non_maximum_suppression.py:
# --------------------------------------------------------------------------------
# Copyright 2023, Yahoo Inc.
# Licensed under the Apache License, Version 2.0.
# See the accompanying LICENSE file for terms.
import csv
import math
from abc import abstractmethod
from pathlib import Path
from typing import Optional

import numpy as np

from spivak.data.label_map import LabelMap

MIN_DETECTION_SCORE_SUPPRESS = 0.0
# This was tweaked by looking at the results and the size of the resulting
# JSON files. In the experiments, 1e-3 was already enough to give good final
# scores, so 1e-5 is a bit of overkill, but JSON file sizes are still
# manageable (around 2 or 3MB).
MIN_DETECTION_SCORE_LINEAR = 1e-5


class FlexibleNonMaximumSuppression:
    """Optional per-class non-maximum suppression with pluggable decay."""

    NMS_COLUMN_WINDOW = "window"
    NMS_COLUMN_LABEL = "label"

    def __init__(
            self, nms_on: bool, class_windows: Optional[np.ndarray],
            score_decay: "ScoreDecayInterface") -> None:
        """class_windows are not in seconds, but in frames."""
        self._nms_on = nms_on
        self._class_windows = class_windows
        self._score_decay = score_decay

    def maybe_apply(self, detection_scores: np.ndarray) -> np.ndarray:
        """Apply NMS if it is on, otherwise return the scores unchanged.

        Each column of detection_scores is processed with the window at the
        same index in class_windows.
        """
        if not self._nms_on:
            return detection_scores
        return _flexible_non_maximum_suppression(
            detection_scores, self._class_windows, self._score_decay)

    @staticmethod
    def read_nms_windows(windows_path: Path, label_map: LabelMap) -> np.ndarray:
        """Read the per-class windows from a CSV file, indexed by class.

        The values are returned as read from the file, without any unit
        conversion. Rows whose label is not in the label map are ignored.

        Raises:
            ValueError: If some class of the label map has no window.
        """
        num_classes = label_map.num_classes()
        class_windows_in_seconds = np.empty(num_classes)
        # Set everything to nan, so we can later make sure all values were set.
        class_windows_in_seconds[:] = np.nan
        # newline="" is what the csv module expects for files it reads.
        with windows_path.open("r", newline="") as csv_file:
            reader = csv.DictReader(csv_file)
            for row in reader:
                label = row[FlexibleNonMaximumSuppression.NMS_COLUMN_LABEL]
                # Just ignore the label if it's not in the label map.
                if label in label_map.label_to_int:
                    label_int = label_map.label_to_int[label]
                    class_windows_in_seconds[label_int] = float(row[
                        FlexibleNonMaximumSuppression.NMS_COLUMN_WINDOW])
        # Make sure all values have been filled in. This is raised explicitly
        # rather than asserted, so the check also runs under "python -O".
        missing = np.flatnonzero(np.isnan(class_windows_in_seconds))
        if missing.size > 0:
            raise ValueError(
                f"No NMS window found in {windows_path} for class indices: "
                f"{missing.tolist()}")
        return class_windows_in_seconds


class ScoreDecayInterface:
    """Interface for how scores around a selected maximum are reduced."""

    @abstractmethod
    def decay(self, scores: np.ndarray, max_index: int, window: float) -> None:
        """Decay (or suppress) scores in a window around the max value."""
        pass

    @property
    @abstractmethod
    def min_detection_score(self) -> float:
        """Score below which the NMS loop stops selecting detections."""
        pass


class ScoreDecaySuppress(ScoreDecayInterface):

    def decay(self, scores: np.ndarray, max_index: int, window: float) -> None:
        """This is standard non-maximum suppression, where a score gets
        completely suppressed if it is in the neighborhood of the max value."""
        # Never use a negative radius: the max itself must always be
        # suppressed, otherwise the NMS loop would select it forever.
        int_radius = max(math.floor(window / 2.0), 0)
        start = max(max_index - int_radius, 0)
        end = min(max_index + int_radius + 1, scores.shape[0])
        scores[start:end] = -1

    @property
    def min_detection_score(self) -> float:
        return MIN_DETECTION_SCORE_SUPPRESS


class ScoreDecayLinear(ScoreDecayInterface):
    """Soft suppression: scores are scaled down more the closer they are."""

    def __init__(self, min_weight: float, window_expansion: float) -> None:
        self._min_weight = min_weight
        self._window_expansion = window_expansion

    def decay(self, scores: np.ndarray, max_index: int, window: float) -> None:
        """Scale the scores around max_index in place, then remove the max.

        The weight grows linearly with the distance to max_index, starting at
        min_weight, and is clipped to [0, 1].
        """
        # So that the effect of the linear decay is more comparable to
        # that of the regular NMS (suppression done by ScoreDecaySuppress),
        # we expand the window used for the linear decay here.
        expanded_window = self._window_expansion * window
        radius = expanded_window / 2.0
        # With a zero (or negative) radius there are no neighbors to decay,
        # and dividing by the radius would only produce NaNs.
        if radius > 0:
            radius_ceil = math.ceil(radius)
            start = max(max_index - radius_ceil, 0)
            end = min(max_index + radius_ceil + 1, scores.shape[0])
            frame_range = np.arange(start, end)
            weights = self._min_weight + (1.0 - self._min_weight) * np.abs(
                (frame_range - max_index) / radius)
            clipped_weights = np.clip(weights, 0.0, 1.0)
            scores[frame_range] *= clipped_weights
        # Remove the max_index.
        scores[max_index] = -1

    @property
    def min_detection_score(self) -> float:
        return MIN_DETECTION_SCORE_LINEAR


def _flexible_non_maximum_suppression(
        scores: np.ndarray, class_windows: np.ndarray,
        score_decay: ScoreDecayInterface) -> np.ndarray:
    """Run NMS on each column of scores with its own window."""
    # Apply nms separately for each class.
    nms_scores_list = [
        _single_class_nms(
            scores[:, class_index], class_window, score_decay)
        for class_index, class_window in enumerate(class_windows)
    ]
    return np.column_stack(nms_scores_list)


def _single_class_nms(
        class_scores: np.ndarray, class_window: float,
        score_decay: ScoreDecayInterface) -> np.ndarray:
    """Run greedy NMS on the scores of a single class.

    Returns an array of the same shape, holding the original score at each
    selected position and -1 everywhere else. The input is not modified.
    """
    class_scores_nms = np.zeros(class_scores.shape) - 1
    if class_scores.size == 0:
        # Nothing to select, and np.argmax would fail on an empty array.
        return class_scores_nms
    class_scores_tmp = np.copy(class_scores)
    max_index = int(np.argmax(class_scores_tmp))
    max_score = class_scores_tmp[max_index]
    min_detection_score = score_decay.min_detection_score
    while max_score >= min_detection_score:
        # Copy the highest value over to the final result.
        class_scores_nms[max_index] = max_score
        # Suppress the scores in class_scores_tmp
        score_decay.decay(class_scores_tmp, max_index, class_window)
        # Find the maximum from the remaining values.
        max_index = int(np.argmax(class_scores_tmp))
        max_score = class_scores_tmp[max_index]
    return class_scores_nms


# --------------------------------------------------------------------------------
# /spivak/data/dataset_splits.py:
# --------------------------------------------------------------------------------
# Copyright 2023, Yahoo Inc.
# Licensed under the Apache License, Version 2.0.
# See the accompanying LICENSE file for terms.

import csv
import logging
import random
from collections import defaultdict
from pathlib import Path
from typing import List, Dict

# For reading/writing the splits
from spivak.data.dataset import Task

FIELD_VIDEO_NAME = "video_name"
FIELD_SPLIT_KEY = "split_key"
FIELD_NAMES = [FIELD_VIDEO_NAME, FIELD_SPLIT_KEY]
# Fractions used for splitting the labeled data into train, validation,
# and test sets.
FRACTION_TRAIN = 0.8
FRACTION_VALIDATION = 0.1
# Split types
SPLIT_KEY_TEST = "test"
SPLIT_KEY_TRAIN = "train"
SPLIT_KEY_VALIDATION = "validation"
# For videos that have no ground truth. Used just for visualizing results.
27 | SPLIT_KEY_UNLABELED = "unlabeled" 28 | ALL_SPLIT_KEYS = [ 29 | SPLIT_KEY_TRAIN, SPLIT_KEY_VALIDATION, SPLIT_KEY_TEST, SPLIT_KEY_UNLABELED] 30 | # FRACTION_TEST is implicitly equal to 1 - FRACTION_TRAIN - FRACTION_VALIDATION 31 | JSON_EXTENSION = ".json" 32 | 33 | Split = List[str] 34 | Splits = Dict[str, Split] 35 | 36 | 37 | class SplitPathsProvider: 38 | 39 | def __init__( 40 | self, features_paths: List[Path], labels_dir_dict: Dict[Task, Path], 41 | splits: Splits) -> None: 42 | self._features_paths = features_paths 43 | self._labels_dir_dict = labels_dir_dict 44 | self._splits = splits 45 | 46 | def provide(self, split_key: str): 47 | split = self._splits[split_key] 48 | split_features_paths = [ 49 | features_path for features_path in self._features_paths 50 | if name_from_path(features_path) in split] 51 | if not split_features_paths: 52 | raise ValueError("No feature files from split found in dataset dir") 53 | split_labels_path_dicts = _create_labels_path_dicts( 54 | split_features_paths, self._labels_dir_dict) 55 | return split_features_paths, split_labels_path_dicts 56 | 57 | 58 | def create_features_paths(features_dir: Path, feature_name: str) -> List[Path]: 59 | if not features_dir.is_dir(): 60 | raise ValueError(f"Not a valid directory for features: {features_dir}") 61 | features_paths = sorted(features_dir.glob(f"**/*{feature_name}.npy")) 62 | if not features_paths: 63 | raise ValueError( 64 | f"No features of type {feature_name} found in " 65 | f"features dir {features_dir}") 66 | return features_paths 67 | 68 | 69 | def load_or_create_splits( 70 | features_paths: List[Path], labels_dir_dict: Dict[Task, Path], 71 | splits_path: Path) -> Splits: 72 | if not splits_path.exists(): 73 | _create_splits(features_paths, labels_dir_dict, splits_path) 74 | return _read_splits(splits_path) 75 | 76 | 77 | def name_from_path(features_path: Path) -> str: 78 | return features_path.parent.stem 79 | 80 | 81 | def _create_splits( 82 | features_paths: List[Path], 
labels_dir_dict: Dict[Task, Path], 83 | splits_path: Path) -> None: 84 | labels_path_dicts = _create_labels_path_dicts( 85 | features_paths, labels_dir_dict) 86 | labels_path_exists = [ 87 | _any_labels_exist(labels_path_dict) 88 | for labels_path_dict in labels_path_dicts] 89 | all_names = [ 90 | name_from_path(features_path) for features_path in features_paths] 91 | unlabeled_split = [ 92 | name for name, labels_path_exists in zip(all_names, labels_path_exists) 93 | if not labels_path_exists] 94 | labeled_names = [ 95 | name for name, labels_path_exists in zip(all_names, labels_path_exists) 96 | if labels_path_exists] 97 | n_labeled = len(labeled_names) 98 | # Reset random seed for consistency. 99 | random.seed() 100 | shuffled_labeled_names = random.sample(labeled_names, n_labeled) 101 | n_train = round(FRACTION_TRAIN * n_labeled) 102 | n_validation = round(FRACTION_VALIDATION * n_labeled) 103 | validation_end = n_train + n_validation 104 | train_split = sorted(shuffled_labeled_names[:n_train]) 105 | validation_split = sorted(shuffled_labeled_names[n_train:validation_end]) 106 | test_split = sorted(shuffled_labeled_names[validation_end:]) 107 | splits = { 108 | SPLIT_KEY_TRAIN: train_split, SPLIT_KEY_VALIDATION: validation_split, 109 | SPLIT_KEY_TEST: test_split, SPLIT_KEY_UNLABELED: unlabeled_split} 110 | _write_splits(splits_path, splits) 111 | 112 | 113 | def _create_labels_path_dicts( 114 | features_paths: List[Path], labels_dir_dict: Dict[Task, Path] 115 | ) -> List[Dict[Task, Path]]: 116 | labels_path_dicts = [ 117 | _get_labels_path_dict(labels_dir_dict, features_path) 118 | for features_path in features_paths] 119 | return labels_path_dicts 120 | 121 | 122 | def _read_splits(splits_path: Path) -> Splits: 123 | logging.debug(f"Reading dataset splits file at {splits_path}") 124 | splits = defaultdict(list) 125 | with splits_path.open("r") as splits_file: 126 | reader = csv.DictReader(splits_file) 127 | for row in reader: 128 | 
def _write_splits(splits_path: Path, splits: Splits) -> None:
    """Write the splits to a CSV file, one row per (video name, split key).

    Args:
        splits_path: Path of the CSV file to create (overwritten if present).
        splits: Mapping from split key to the list of video names in it.
    """
    logging.warning(f"Creating dataset splits file at {splits_path}")
    # The csv module does its own line-ending handling, so the file must be
    # opened with newline="". Otherwise, on platforms that translate "\n"
    # when writing, each row would end in "\r\r\n".
    with splits_path.open("w", newline="") as splits_file:
        writer = csv.DictWriter(splits_file, fieldnames=FIELD_NAMES)
        writer.writeheader()
        for split_key, split in splits.items():
            for video_name in split:
                writer.writerow({
                    FIELD_VIDEO_NAME: video_name, FIELD_SPLIT_KEY: split_key})
def add_segmentation_graph(
        html_file: TextIO, data_frame: DataFrame,
        category_settings: CategorySettings) -> None:
    """Plot segmentation predictions against labels and write the HTML.

    The x-axis spans from the first to the last value of the time column,
    and the y-axis has one tick for the predictions row and one for the
    labels row.
    """
    figure = _plot_segmentation_and_labels(
        data_frame, category_settings, SEGMENTATION_HEIGHT)
    figure.update_layout(
        title_text=SEGMENTATION_TITLE, legend={"itemsizing": "constant"})
    # Restrict the time axis to the range covered by the data.
    times = data_frame[COLUMN_TIME]
    first_time = times.head(1).item()
    last_time = times.tail(1).item()
    figure.update_xaxes(tickformat='%M:%S.%L', range=[first_time, last_time])
    # Leave some vertical room below and above the two rows of markers.
    tick_values = [SOURCE_FLOAT_PREDICTION, SOURCE_FLOAT_LABEL]
    tick_texts = [
        SEGMENTATION_PREDICTION_TICK_LABEL, COLUMN_SEGMENTATION_LABEL]
    lowest_y = min(tick_values) - 3.0
    highest_y = max(tick_values) + 3.0
    figure.update_yaxes(
        range=[lowest_y, highest_y], tickvals=tick_values,
        ticktext=tick_texts, title_text=COLUMN_SOURCE)
    figure.write_html(html_file)
def _plot_segmentation_and_labels(
        data_frame: DataFrame, category_settings: CategorySettings,
        height: int) -> Figure:
    """Create a scatter figure with a predictions row and a labels row.

    Predictions are the rows with the maximum segmentation score at each time
    instant, while labels are the rows whose segmentation label equals 1.
    Each group is placed at its own constant y-value and colored by class.
    """
    # Rows corresponding to the predicted class at each time instant.
    predicted_rows = _max_prediction_per_time_instant(data_frame)
    predicted_rows[COLUMN_SOURCE_FLOAT] = SOURCE_FLOAT_PREDICTION
    # Rows corresponding to the ground-truth labels. Copy before adding the
    # column, so that we do not write into a selection of the input frame.
    is_labeled = data_frame[COLUMN_SEGMENTATION_LABEL] == 1
    labeled_rows = data_frame.loc[is_labeled].copy()
    labeled_rows[COLUMN_SOURCE_FLOAT] = SOURCE_FLOAT_LABEL
    # Plot both groups together in a single scatter plot.
    combined_rows = pandas.concat([predicted_rows, labeled_rows])
    figure = px.scatter(
        combined_rows, x=COLUMN_TIME, y=COLUMN_SOURCE_FLOAT,
        color=COLUMN_CLASS, height=height,
        category_orders=category_settings.category_order,
        color_discrete_map=category_settings.discrete_color_map)
    marker_style = {
        "symbol": "line-ns-open", "size": 35, "line": {"width": 1}}
    figure.update_traces(marker=marker_style, selector={"mode": "markers"})
    return figure
marker=prediction_marker) 146 | 147 | 148 | def _create_label_marker(color: str) -> Dict: 149 | return dict( 150 | size=2, color=color, line_color=color, symbol="circle", line_width=0) 151 | 152 | 153 | def _create_prediction_marker(color: str) -> Dict: 154 | return dict( 155 | size=2, color=color, line_color=color, symbol="circle", line_width=0) 156 | -------------------------------------------------------------------------------- /Code_of_Conduct.md: -------------------------------------------------------------------------------- 1 | # Yahoo Inc Open Source Code of Conduct 2 | 3 | ## Summary 4 | This Code of Conduct is our way to encourage good behavior and discourage bad behavior in our open source projects. We invite participation from many people to bring different perspectives to our projects. We will do our part to foster a welcoming and professional environment free of harassment. We expect participants to communicate professionally and thoughtfully during their involvement with this project. 5 | 6 | Participants may lose their good standing by engaging in misconduct. For example: insulting, threatening, or conveying unwelcome sexual content. We ask participants who observe conduct issues to report the incident directly to the project's Response Team at opensource-conduct@yahooinc.com. Yahoo will assign a respondent to address the issue. We may remove harassers from this project. 7 | 8 | This code does not replace the terms of service or acceptable use policies of the websites used to support this project. We acknowledge that participants may be subject to additional conduct terms based on their employment which may govern their online expressions. 9 | 10 | ## Details 11 | This Code of Conduct makes our expectations of participants in this community explicit. 12 | * We forbid harassment and abusive speech within this community. 13 | * We request participants to report misconduct to the project’s Response Team. 
14 | * We urge participants to refrain from using discussion forums to play out a fight. 15 | 16 | ### Expected Behaviors 17 | We expect participants in this community to conduct themselves professionally. Since our primary mode of communication is text on an online forum (e.g. issues, pull requests, comments, emails, or chats) devoid of vocal tone, gestures, or other context that is often vital to understanding, it is important that participants are attentive to their interaction style. 18 | 19 | * **Assume positive intent.** We ask community members to assume positive intent on the part of other people’s communications. We may disagree on details, but we expect all suggestions to be supportive of the community goals. 20 | * **Respect participants.** We expect occasional disagreements. Open Source projects are learning experiences. Ask, explore, challenge, and then _respectfully_ state if you agree or disagree. If your idea is rejected, be more persuasive not bitter. 21 | * **Welcoming to new members.** New members bring new perspectives. Some ask questions that have been addressed before. _Kindly_ point to existing discussions. Everyone is new to every project once. 22 | * **Be kind to beginners.** Beginners use open source projects to get experience. They might not be talented coders yet, and projects should not accept poor quality code. But we were all beginners once, and we need to engage kindly. 23 | * **Consider your impact on others.** Your work will be used by others, and you depend on the work of others. We expect community members to be considerate and establish a balance between their self-interest and communal interest. 24 | * **Use words carefully.** We may not understand intent when you say something ironic. Often, people will misinterpret sarcasm in online communications. We ask community members to communicate plainly.
25 | * **Leave with class.** When you wish to resign from participating in this project for any reason, you are free to fork the code and create a competitive project. Open Source explicitly allows this. Your exit should not be dramatic or bitter. 26 | 27 | ### Unacceptable Behaviors 28 | Participants remain in good standing when they do not engage in misconduct or harassment (some examples follow). We do not list all forms of harassment, nor imply some forms of harassment are not worthy of action. Any participant who *feels* harassed or *observes* harassment, should report the incident to the Response Team. 29 | * **Don't be a bigot.** Calling out project members by their identity or background in a negative or insulting manner. This includes, but is not limited to, slurs or insinuations related to protected or suspect classes e.g. race, color, citizenship, national origin, political belief, religion, sexual orientation, gender identity and expression, age, size, culture, ethnicity, genetic features, language, profession, national minority status, mental or physical ability. 30 | * **Don't insult.** Insulting remarks about a person’s lifestyle practices. 31 | * **Don't dox.** Revealing private information about other participants without explicit permission. 32 | * **Don't intimidate.** Threats of violence or intimidation of any project member. 33 | * **Don't creep.** Unwanted sexual attention or content unsuited for the subject of this project. 34 | * **Don't inflame.** We ask that victims of harassment not address their grievances in the public forum, as this often intensifies the problem. Report it, and let us address it off-line. 35 | * **Don't disrupt.** Sustained disruptions in a discussion.
36 | 37 | ### Reporting Issues 38 | If you experience or witness misconduct, or have any other concerns about the conduct of members of this project, please report it by contacting our Response Team at opensource-conduct@yahooinc.com who will handle your report with discretion. Your report should include: 39 | * Your preferred contact information. We cannot process anonymous reports. 40 | * Names (real or usernames) of those involved in the incident. 41 | * Your account of what occurred, and if the incident is ongoing. Please provide links to or transcripts of the publicly available records (e.g. a mailing list archive or a public IRC logger), so that we can review it. 42 | * Any additional information that may be helpful to achieve resolution. 43 | 44 | After filing a report, a representative will contact you directly to review the incident and ask additional questions. If a member of the Yahoo Response Team is named in an incident report, that member will be recused from handling your incident. If the complaint originates from a member of the Response Team, it will be addressed by a different member of the Response Team. We will consider reports to be confidential for the purpose of protecting victims of abuse. 45 | 46 | ### Scope 47 | Yahoo will assign a Response Team member with admin rights on the project and legal rights on the project copyright. The Response Team is empowered to restrict some privileges to the project as needed. Since this project is governed by an open source license, any participant may fork the code under the terms of the project license. The Response Team’s goal is to preserve the project if possible, and will restrict or remove participation from those who disrupt the project. 48 | 49 | This code does not replace the terms of service or acceptable use policies that are provided by the websites used to support this community. Nor does this code apply to communications or actions that take place outside of the context of this community. 
Many participants in this project are also subject to codes of conduct based on their employment. This code is a social-contract that informs participants of our social expectations. It is not a terms of service or legal contract. 50 | 51 | ## License and Acknowledgment. 52 | This text is shared under the [CC-BY-4.0 license](https://creativecommons.org/licenses/by/4.0/). This code is based on a study conducted by the [TODO Group](https://todogroup.org/) of many codes used in the open source community. If you have feedback about this code, contact our Response Team at the address listed above. 53 | -------------------------------------------------------------------------------- /data/splits/SoccerNetGamesTest.json: -------------------------------------------------------------------------------- 1 | { 2 | "england_epl": { 3 | "2014-2015": [ 4 | "2015-05-17 - 18-00 Manchester United 1 - 1 Arsenal" 5 | ], 6 | "2015-2016": [ 7 | "2015-08-16 - 18-00 Manchester City 3 - 0 Chelsea", 8 | "2015-08-23 - 15-30 West Brom 2 - 3 Chelsea", 9 | "2015-08-29 - 17-00 Liverpool 0 - 3 West Ham", 10 | "2015-09-20 - 18-00 Southampton 2 - 3 Manchester United", 11 | "2015-09-26 - 19-30 Newcastle Utd 2 - 2 Chelsea", 12 | "2015-10-03 - 19-30 Chelsea 1 - 3 Southampton", 13 | "2015-10-24 - 17-00 West Ham 2 - 1 Chelsea", 14 | "2015-11-07 - 20-30 Stoke City 1 - 0 Chelsea", 15 | "2015-11-08 - 19-00 Arsenal 1 - 1 Tottenham", 16 | "2015-12-28 - 20-30 Manchester United 0 - 0 Chelsea", 17 | "2016-02-03 - 22-45 Watford 0 - 0 Chelsea", 18 | "2016-03-01 - 22-45 Norwich 1 - 2 Chelsea" 19 | ], 20 | "2016-2017": [ 21 | "2016-08-27 - 14-30 Tottenham 1 - 1 Liverpool", 22 | "2016-09-24 - 14-30 Manchester United 4 - 1 Leicester", 23 | "2016-10-15 - 14-30 Chelsea 3 - 0 Leicester", 24 | "2017-01-21 - 15-30 Liverpool 2 - 3 Swansea", 25 | "2017-05-06 - 17-00 Leicester 3 - 0 Watford" 26 | ] 27 | }, 28 | "europe_uefa-champions-league": { 29 | "2014-2015": [ 30 | "2014-11-04 - 20-00 Zenit Petersburg 1 - 2 Bayer 
Leverkusen", 31 | "2015-02-24 - 22-45 Manchester City 1 - 2 Barcelona", 32 | "2015-03-10 - 22-45 Real Madrid 3 - 4 Schalke", 33 | "2015-03-17 - 22-45 Monaco 0 - 2 Arsenal", 34 | "2015-04-15 - 21-45 FC Porto 3 - 1 Bayern Munich", 35 | "2015-04-22 - 21-45 Real Madrid 1 - 0 Atl. Madrid", 36 | "2015-05-05 - 21-45 Juventus 2 - 1 Real Madrid" 37 | ], 38 | "2015-2016": [ 39 | "2015-09-29 - 21-45 Bayern Munich 5 - 0 D. Zagreb", 40 | "2015-11-03 - 22-45 Real Madrid 1 - 0 Paris SG", 41 | "2015-11-03 - 22-45 Sevilla 1 - 3 Manchester City", 42 | "2015-11-03 - 22-45 Shakhtar Donetsk 4 - 0 Malmo FF", 43 | "2015-11-25 - 22-45 Shakhtar Donetsk 3 - 4 Real Madrid", 44 | "2016-04-05 - 21-45 Bayern Munich 1 - 0 Benfica" 45 | ], 46 | "2016-2017": [ 47 | "2016-11-01 - 20-45 Besiktas 1 - 1 Napoli", 48 | "2016-11-01 - 22-45 Manchester City 3 - 1 Barcelona", 49 | "2016-11-23 - 22-45 Arsenal 2 - 2 Paris SG", 50 | "2017-03-08 - 22-45 Barcelona 6 - 1 Paris SG", 51 | "2017-04-12 - 21-45 Bayern Munich 1 - 2 Real Madrid", 52 | "2017-05-02 - 21-45 Real Madrid 3 - 0 Atl. Madrid" 53 | ] 54 | }, 55 | "france_ligue-1": { 56 | "2016-2017": [ 57 | "2016-08-28 - 21-45 Monaco 3 - 1 Paris SG", 58 | "2016-11-30 - 23-00 Paris SG 2 - 0 Angers" 59 | ] 60 | }, 61 | "germany_bundesliga": { 62 | "2014-2015": [ 63 | "2015-05-09 - 16-30 Bayern Munich 0 - 1 FC Augsburg" 64 | ], 65 | "2015-2016": [ 66 | "2015-08-29 - 19-30 Bayern Munich 3 - 0 Bayer Leverkusen", 67 | "2015-09-12 - 16-30 Bayern Munich 2 - 1 FC Augsburg", 68 | "2015-10-24 - 16-30 Bayern Munich 4 - 0 FC Koln", 69 | "2015-11-08 - 17-30 Dortmund 3 - 2 Schalke" 70 | ], 71 | "2016-2017": [ 72 | "2016-09-10 - 19-30 RB Leipzig 1 - 0 Dortmund", 73 | "2016-10-01 - 19-30 Bayer Leverkusen 2 - 0 Dortmund", 74 | "2016-11-05 - 17-30 Hamburger SV 2 - 5 Dortmund", 75 | "2016-11-19 - 20-30 Dortmund 1 - 0 Bayern Munich", 76 | "2016-12-16 - 22-30 Hoffenheim 2 - 2 Dortmund", 77 | "2017-01-21 - 17-30 SV Werder Bremen 1 - 2 Dortmund", 78 | "2017-01-29 - 19-30 1. 
FSV Mainz 05 1 - 1 Dortmund", 79 | "2017-03-04 - 17-30 Dortmund 6 - 2 Bayer Leverkusen", 80 | "2017-04-29 - 16-30 Dortmund 0 - 0 FC Koln" 81 | ] 82 | }, 83 | "italy_serie-a": { 84 | "2014-2015": [ 85 | "2015-04-29 - 21-45 Juventus 3 - 2 Fiorentina" 86 | ], 87 | "2015-2016": [ 88 | "2015-08-29 - 21-45 AC Milan 2 - 1 Empoli", 89 | "2015-09-20 - 16-00 Genoa 0 - 2 Juventus", 90 | "2015-09-27 - 21-45 Inter 1 - 4 Fiorentina" 91 | ], 92 | "2016-2017": [ 93 | "2016-08-27 - 21-45 Napoli 4 - 2 AC Milan", 94 | "2016-09-11 - 16-00 AC Milan 0 - 1 Udinese", 95 | "2016-09-20 - 21-45 AC Milan 2 - 0 Lazio", 96 | "2016-09-24 - 21-45 Napoli 2 - 0 Chievo", 97 | "2016-09-25 - 13-30 Torino 3 - 1 AS Roma", 98 | "2016-09-25 - 21-45 Fiorentina 0 - 0 AC Milan", 99 | "2016-10-02 - 21-45 AS Roma 2 - 1 Inter", 100 | "2016-11-20 - 17-00 Atalanta 2 - 1 AS Roma", 101 | "2016-11-26 - 22-45 Empoli 1 - 4 AC Milan", 102 | "2016-12-04 - 17-00 Lazio 0 - 2 AS Roma", 103 | "2017-01-08 - 17-00 Genoa 0 - 1 AS Roma", 104 | "2017-01-29 - 17-00 Sampdoria 3 - 2 AS Roma", 105 | "2017-02-07 - 22-45 AS Roma 4 - 0 Fiorentina", 106 | "2017-02-25 - 20-00 Napoli 0 - 2 Atalanta", 107 | "2017-02-26 - 22-45 Inter 1 - 3 AS Roma", 108 | "2017-04-01 - 21-45 AS Roma 2 - 0 Empoli", 109 | "2017-04-30 - 21-45 Inter 0 - 1 Napoli", 110 | "2017-05-20 - 21-45 Napoli 4 - 1 Fiorentina" 111 | ] 112 | }, 113 | "spain_laliga": { 114 | "2014-2015": [ 115 | "2015-02-14 - 20-00 Real Madrid 2 - 0 Dep. 
La Coruna", 116 | "2015-04-18 - 21-00 Real Madrid 3 - 1 Malaga", 117 | "2015-04-25 - 17-00 Espanyol 0 - 2 Barcelona", 118 | "2015-04-29 - 21-00 Real Madrid 3 - 0 Almeria", 119 | "2015-05-02 - 17-00 Cordoba 0 - 8 Barcelona", 120 | "2015-05-09 - 19-00 Barcelona 2 - 0 Real Sociedad" 121 | ], 122 | "2015-2016": [ 123 | "2015-08-29 - 23-30 Real Madrid 5 - 0 Betis", 124 | "2015-09-19 - 17-00 Real Madrid 1 - 0 Granada CF", 125 | "2015-11-08 - 18-00 Barcelona 3 - 0 Villarreal", 126 | "2015-12-05 - 18-00 Real Madrid 4 - 1 Getafe", 127 | "2015-12-30 - 18-00 Real Madrid 3 - 1 Real Sociedad", 128 | "2016-02-27 - 18-00 Real Madrid 0 - 1 Atl. Madrid", 129 | "2016-03-02 - 23-00 Levante 1 - 3 Real Madrid", 130 | "2016-05-08 - 18-00 Real Madrid 3 - 2 Valencia", 131 | "2016-05-14 - 18-00 Dep. La Coruna 0 - 2 Real Madrid" 132 | ], 133 | "2016-2017": [ 134 | "2016-09-10 - 21-30 Barcelona 1 - 2 Alaves", 135 | "2016-09-21 - 21-00 Real Madrid 1 - 1 Villarreal", 136 | "2016-09-24 - 21-45 Las Palmas 2 - 2 Real Madrid", 137 | "2016-11-26 - 18-15 Real Madrid 2 - 1 Gijon", 138 | "2016-12-18 - 22-45 Barcelona 4 - 1 Espanyol", 139 | "2017-02-11 - 22-45 Osasuna 1 - 3 Real Madrid", 140 | "2017-03-12 - 22-45 Real Madrid 2 - 1 Betis", 141 | "2017-04-02 - 17-15 Real Madrid 3 - 0 Alaves", 142 | "2017-04-08 - 21-45 Malaga 2 - 0 Barcelona", 143 | "2017-04-26 - 20-30 Barcelona 7 - 1 Osasuna" 144 | ] 145 | } 146 | } -------------------------------------------------------------------------------- /data/splits/SoccerNetGamesValid.json: -------------------------------------------------------------------------------- 1 | { 2 | "england_epl": { 3 | "2014-2015": [ 4 | "2015-04-11 - 19-30 Burnley 0 - 1 Arsenal" 5 | ], 6 | "2015-2016": [ 7 | "2015-08-30 - 18-00 Swansea 2 - 1 Manchester United", 8 | "2015-09-26 - 17-00 Leicester 2 - 5 Arsenal", 9 | "2015-09-26 - 17-00 Manchester United 3 - 0 Sunderland", 10 | "2015-10-03 - 17-00 Manchester City 6 - 1 Newcastle Utd", 11 | "2015-12-26 - 18-00 Chelsea 2 - 2 
Watford", 12 | "2016-01-23 - 20-30 West Ham 2 - 2 Manchester City", 13 | "2016-01-24 - 19-00 Arsenal 0 - 1 Chelsea", 14 | "2016-02-13 - 20-30 Chelsea 5 - 1 Newcastle Utd", 15 | "2016-02-27 - 18-00 Southampton 1 - 2 Chelsea", 16 | "2016-03-05 - 18-00 Manchester City 4 - 0 Aston Villa", 17 | "2016-03-20 - 19-00 Manchester City 0 - 1 Manchester United", 18 | "2016-04-23 - 17-00 Bournemouth 1 - 4 Chelsea" 19 | ], 20 | "2016-2017": [ 21 | "2016-10-01 - 14-30 Swansea 1 - 2 Liverpool", 22 | "2016-10-02 - 18-30 Burnley 0 - 1 Arsenal", 23 | "2016-10-22 - 17-00 Arsenal 0 - 0 Middlesbrough", 24 | "2016-10-29 - 19-30 Crystal Palace 2 - 4 Liverpool", 25 | "2016-12-04 - 16-30 Bournemouth 4 - 3 Liverpool", 26 | "2017-04-09 - 18-00 Everton 4 - 2 Leicester" 27 | ] 28 | }, 29 | "europe_uefa-champions-league": { 30 | "2014-2015": [ 31 | "2014-11-04 - 22-45 Real Madrid 1 - 0 Liverpool", 32 | "2014-11-05 - 22-45 Ajax 0 - 2 Barcelona", 33 | "2014-11-05 - 22-45 Manchester City 1 - 2 CSKA Moscow", 34 | "2014-12-09 - 22-45 Galatasaray 1 - 4 Arsenal", 35 | "2014-12-09 - 22-45 Real Madrid 4 - 0 Ludogorets", 36 | "2015-03-11 - 22-45 Bayern Munich 7 - 0 Shakhtar Donetsk", 37 | "2015-04-15 - 21-45 Paris SG 1 - 3 Barcelona", 38 | "2015-04-21 - 21-45 Barcelona 2 - 0 Paris SG" 39 | ], 40 | "2015-2016": [ 41 | "2015-09-15 - 21-45 PSV 2 - 1 Manchester United", 42 | "2015-09-16 - 21-45 Chelsea 4 - 0 Maccabi Tel Aviv", 43 | "2015-09-16 - 21-45 Dyn. 
Kiev 2 - 2 FC Porto", 44 | "2015-09-16 - 21-45 Olympiakos Piraeus 0 - 3 Bayern Munich", 45 | "2015-09-29 - 21-45 Barcelona 2 - 1 Bayer Leverkusen", 46 | "2015-09-29 - 21-45 FC Porto 2 - 1 Chelsea", 47 | "2015-11-03 - 22-45 Benfica 2 - 1 Galatasaray", 48 | "2015-11-04 - 22-45 Barcelona 3 - 0 BATE", 49 | "2015-11-24 - 20-00 BATE 1 - 1 Bayer Leverkusen", 50 | "2015-11-25 - 22-45 Juventus 1 - 0 Manchester City" 51 | ], 52 | "2016-2017": [ 53 | "2016-09-28 - 21-45 Napoli 4 - 2 Benfica", 54 | "2016-10-19 - 21-45 Paris SG 3 - 0 Basel" 55 | ] 56 | }, 57 | "france_ligue-1": { 58 | "2015-2016": [ 59 | "2015-09-19 - 18-30 Reims 1 - 1 Paris SG" 60 | ], 61 | "2016-2017": [ 62 | "2016-08-21 - 21-45 Paris SG 3 - 0 Metz", 63 | "2016-09-09 - 21-45 Paris SG 1 - 1 St Etienne", 64 | "2016-09-20 - 22-00 Paris SG 3 - 0 Dijon", 65 | "2016-10-23 - 21-45 Paris SG 0 - 0 Marseille", 66 | "2016-12-03 - 19-00 Montpellier 3 - 0 Paris SG", 67 | "2017-03-04 - 19-00 Paris SG 1 - 0 Nancy", 68 | "2017-04-09 - 22-00 Paris SG 4 - 0 Guingamp", 69 | "2017-05-14 - 22-00 St Etienne 0 - 5 Paris SG" 70 | ] 71 | }, 72 | "germany_bundesliga": { 73 | "2014-2015": [ 74 | "2015-04-25 - 19-30 Bayern Munich 1 - 0 Hertha Berlin", 75 | "2015-05-02 - 16-30 Hoffenheim 1 - 1 Dortmund" 76 | ], 77 | "2015-2016": [ 78 | "2015-09-20 - 18-30 Dortmund 3 - 0 Bayer Leverkusen", 79 | "2015-10-04 - 18-30 Bayern Munich 5 - 1 Dortmund", 80 | "2015-11-07 - 17-30 Bayern Munich 4 - 0 VfB Stuttgart", 81 | "2016-04-23 - 16-30 Hertha Berlin 0 - 2 Bayern Munich" 82 | ], 83 | "2016-2017": [ 84 | "2016-12-03 - 17-30 Dortmund 4 - 1 B. 
Monchengladbach", 85 | "2017-02-25 - 17-30 SC Freiburg 0 - 3 Dortmund" 86 | ] 87 | }, 88 | "italy_serie-a": { 89 | "2014-2015": [ 90 | "2015-04-25 - 19-00 Udinese 2 - 1 AC Milan", 91 | "2015-05-10 - 21-45 Lazio 1 - 2 Inter", 92 | "2015-05-17 - 13-30 Sassuolo 3 - 2 AC Milan" 93 | ], 94 | "2015-2016": [ 95 | "2015-09-26 - 21-45 Napoli 2 - 1 Juventus" 96 | ], 97 | "2016-2017": [ 98 | "2016-08-20 - 19-00 AS Roma 4 - 0 Udinese", 99 | "2016-09-18 - 21-45 Fiorentina 1 - 0 AS Roma", 100 | "2016-09-21 - 21-45 Genoa 0 - 0 Napoli", 101 | "2016-10-02 - 19-00 AC Milan 4 - 3 Sassuolo", 102 | "2016-10-29 - 21-45 Juventus 2 - 1 Napoli", 103 | "2016-11-05 - 22-45 Napoli 1 - 1 Lazio", 104 | "2016-12-12 - 23-00 AS Roma 1 - 0 AC Milan", 105 | "2017-01-22 - 22-45 AS Roma 1 - 0 Cagliari", 106 | "2017-02-19 - 17-00 Chievo 1 - 3 Napoli", 107 | "2017-03-12 - 22-45 Palermo 0 - 3 AS Roma", 108 | "2017-04-02 - 21-45 Napoli 1 - 1 Juventus", 109 | "2017-04-15 - 21-45 Napoli 3 - 0 Udinese", 110 | "2017-04-30 - 13-30 AS Roma 1 - 3 Lazio", 111 | "2017-05-06 - 19-00 Napoli 3 - 1 Cagliari" 112 | ] 113 | }, 114 | "spain_laliga": { 115 | "2014-2015": [ 116 | "2015-04-11 - 17-00 Real Madrid 3 - 0 Eibar", 117 | "2015-04-11 - 21-00 Sevilla 2 - 2 Barcelona", 118 | "2015-04-28 - 21-00 Barcelona 6 - 0 Getafe", 119 | "2015-05-02 - 19-00 Atl. Madrid 0 - 0 Ath Bilbao", 120 | "2015-05-02 - 21-00 Sevilla 2 - 3 Real Madrid", 121 | "2015-05-23 - 21-30 Real Madrid 7 - 3 Getafe" 122 | ], 123 | "2015-2016": [ 124 | "2015-09-12 - 21-30 Atl. Madrid 1 - 2 Barcelona", 125 | "2015-09-26 - 17-00 Barcelona 2 - 1 Las Palmas", 126 | "2015-10-24 - 17-00 Celta Vigo 1 - 3 Real Madrid", 127 | "2015-12-13 - 22-30 Villarreal 1 - 0 Real Madrid", 128 | "2015-12-20 - 18-00 Real Madrid 1 - 0 Rayo Vallecano", 129 | "2016-01-03 - 22-30 Valencia 2 - 2 Real Madrid", 130 | "2016-01-09 - 22-30 Real Madrid 5 - 0 Dep. 
La Coruna", 131 | "2016-02-13 - 18-00 Real Madrid 4 - 2 Ath Bilbao", 132 | "2016-03-20 - 22-30 Real Madrid 4 - 0 Sevilla" 133 | ], 134 | "2016-2017": [ 135 | "2016-08-21 - 21-15 Real Sociedad 0 - 3 Real Madrid", 136 | "2016-09-17 - 14-00 Leganes 1 - 5 Barcelona", 137 | "2016-10-02 - 21-45 Celta Vigo 4 - 3 Barcelona", 138 | "2016-10-15 - 17-15 Barcelona 4 - 0 Dep. La Coruna", 139 | "2017-01-07 - 15-00 Real Madrid 5 - 0 Granada CF", 140 | "2017-01-21 - 18-15 Real Madrid 2 - 1 Malaga", 141 | "2017-02-18 - 18-15 Real Madrid 2 - 0 Espanyol", 142 | "2017-04-23 - 21-45 Real Madrid 2 - 3 Barcelona", 143 | "2017-04-29 - 17-15 Real Madrid 2 - 1 Valencia", 144 | "2017-04-29 - 21-45 Espanyol 0 - 3 Barcelona", 145 | "2017-05-14 - 21-00 Las Palmas 1 - 4 Barcelona" 146 | ] 147 | } 148 | } -------------------------------------------------------------------------------- /bin/transform_features.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright 2023, Yahoo Inc. 4 | # Licensed under the Apache License, Version 2.0. 5 | # See the accompanying LICENSE file for terms. 
def main() -> None:
    """Transform the feature files of every video found in the input dirs.

    Reads the command-line arguments, validates them, creates the output
    directories, then normalizes, resamples and concatenates the input
    features of each video into a single output feature file.

    Raises:
        ValueError: If an input directory does not exist, or if the number
            of factors differs from the number of input feature names.
    """
    args = _get_command_line_arguments()
    input_dirs = [Path(p) for p in args[Args.INPUT_DIRS]]
    for input_dir in input_dirs:
        if not input_dir.is_dir():
            raise ValueError(f"Input directory failed is_dir(): {input_dir}")
    input_feature_names = args[Args.INPUT_FEATURE_NAMES]
    factors = args[Args.FACTORS]
    # The factors are later zipped with the per-feature normalizers, and zip
    # silently stops at the shortest input. Without this check, a mismatch
    # would drop features from the output (or ignore extra factors) without
    # any warning. Fail early, before any file is read or written.
    if len(factors) != len(input_feature_names):
        raise ValueError(
            f"Expected {len(input_feature_names)} factors (one per input "
            f"feature name), but got {len(factors)}")
    normalizers = _read_normalizers(
        args[Args.NORMALIZERS], len(input_feature_names))
    resampling = args[Args.RESAMPLING]
    output_feature_name = args[Args.OUTPUT_FEATURE_NAME]
    output_dir = Path(args[Args.OUTPUT_DIR])
    output_dir.mkdir(parents=True, exist_ok=True)
    video_feature_infos = create_video_feature_infos(
        input_dirs, input_feature_names, output_dir, output_feature_name)
    print(f"Found {len(video_feature_infos)} video feature files")
    make_output_directories(video_feature_infos)
    for video_feature_info in video_feature_infos:
        _transform_features_file(
            video_feature_info, normalizers, factors, resampling)
def _transform_features_file(
        video_feature_info: VideoFeatureInfo, normalizers: List[Any],
        factors: List[float], resampling: str) -> None:
    """Normalize, resample and concatenate the input features of one video.

    Each file in video_feature_info.input_paths is loaded, optionally
    normalized, and resampled with its factor using the given resampling
    strategy. The results are clipped to the smallest number of frames among
    them, concatenated along axis 1, and saved as a numpy file at
    video_feature_info.output_path.
    """
    resampled_per_input = []
    per_input_settings = zip(
        video_feature_info.input_paths, normalizers, factors)
    for input_path, normalizer, factor in per_input_settings:
        loaded = np.load(str(input_path))
        # A falsy normalizer (e.g. None) means the features are used as is.
        to_resample = normalizer.transform(loaded) if normalizer else loaded
        resampled_per_input.append(
            _resample_features(to_resample, factor, resampling))
    # The resampled inputs may differ in length, so clip all of them to the
    # shortest one before concatenating.
    n_common_frames = min(
        resampled.shape[0] for resampled in resampled_per_input)
    clipped_per_input = [
        clip_frames(resampled, n_common_frames)
        for resampled in resampled_per_input]
    concatenated = np.concatenate(clipped_per_input, axis=1)
    print(f"Writing transformed features to {video_feature_info.output_path}")
    np.save(str(video_feature_info.output_path), concatenated)
def _get_command_line_arguments() -> Dict:
    """Parse the command line and return the arguments as a dictionary.

    The keys of the returned dictionary are the option names listed in Args.
    """
    def flag(name: str) -> str:
        # Command-line flags are the Args names prefixed with two dashes.
        return "--" + name

    resampling_choices = [
        RESAMPLING_ZEROS, RESAMPLING_INTERPOLATE, RESAMPLING_REPEAT,
        RESAMPLING_MAX]
    normalizers_help = (
        "One or more pickle files containing normalizers for the features. "
        f"Use \"{IDENTITY_NORMALIZER}\" to apply no normalization to "
        f"the corresponding feature. Normalization is applied before "
        f"resampling.")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        flag(Args.INPUT_DIRS), type=str, nargs="+", required=True,
        help='One or more input directories containing features')
    parser.add_argument(
        flag(Args.INPUT_FEATURE_NAMES), type=str, nargs="+", required=True,
        help='One or more feature file name endings (e.g. ResNET_TF2, without '
             'the .npy part)')
    parser.add_argument(
        flag(Args.OUTPUT_DIR), type=str, required=True,
        help="Directory for the output features")
    parser.add_argument(
        flag(Args.OUTPUT_FEATURE_NAME), type=str, required=True,
        help="Name of the output feature for feature file name endings "
             "(e.g. ResNET_TF2, without the .npy part)")
    parser.add_argument(
        flag(Args.NORMALIZERS), type=str, nargs="+", required=False,
        help=normalizers_help)
    parser.add_argument(
        flag(Args.FACTORS), type=float, nargs="+", required=True,
        help="One or more factors (floats) for the new sampling rates")
    parser.add_argument(
        flag(Args.RESAMPLING), type=str, required=True,
        choices=resampling_choices,
        help="How to resample the features to achieve higher sampling rates")
    return vars(parser.parse_args())
class WeightCreatorInterface(metaclass=ABCMeta):
    """Abstract interface for objects that produce training weights.

    Implementations first prepare per-video "weight inputs" from the labels
    and targets of a video, and later convert chunks of those inputs into
    the actual weights, either through TensorFlow or through numpy.
    """

    @abstractmethod
    def video_weight_inputs(self, video_labels, video_targets):
        """Create the weight inputs for a whole video.

        The result is what later gets handed, in chunks, to
        tf_chunk_weights() or chunk_weights().
        """

    @abstractmethod
    def tf_chunk_weights(self, chunk_weight_inputs):
        """Convert a chunk of weight inputs into weights (TensorFlow flow)."""

    @abstractmethod
    def chunk_weights(self, chunk_weight_inputs):
        """Convert a chunk of weight inputs into weights (numpy flow)."""
def read_class_weights(
        class_weights_path: Path, label_map: LabelMap) -> np.ndarray:
    """Read one weight per class from a CSV file.

    The CSV file must have a header with the columns named by
    CLASS_WEIGHTS_COLUMN_LABEL and CLASS_WEIGHTS_COLUMN_WEIGHT. Each row
    assigns a weight to the class whose integer index is given by
    label_map.label_to_int for the row's label.

    Args:
        class_weights_path: Path to the CSV file. If it does not exist, all
            classes get a weight of 1.0.
        label_map: Provides num_classes() and the label_to_int mapping.

    Returns:
        Array of length label_map.num_classes() with the class weights.

    Raises:
        KeyError: If the file mentions a label missing from label_to_int.
        ValueError: If a weight cannot be parsed as a float, or if some class
            did not receive a (non-NaN) weight from the file.
    """
    num_classes = label_map.num_classes()
    if not class_weights_path.exists():
        return np.ones(num_classes)
    # Start with NaN everywhere so that classes that never get a weight from
    # the file can be detected afterwards.
    class_weights = np.full(num_classes, np.nan)
    # The csv module documentation recommends opening files with newline="".
    with class_weights_path.open("r", newline="") as csv_file:
        for row in csv.DictReader(csv_file):
            label = row[CLASS_WEIGHTS_COLUMN_LABEL]
            label_int = label_map.label_to_int[label]
            class_weights[label_int] = float(row[CLASS_WEIGHTS_COLUMN_WEIGHT])
    # Raise explicitly instead of using assert: asserts are removed when
    # Python runs with -O, which would let NaN weights reach training.
    missing_indexes = np.flatnonzero(np.isnan(class_weights))
    if missing_indexes.size > 0:
        raise ValueError(
            f"Class weights file {class_weights_path} does not define a "
            f"valid weight for class index(es) {missing_indexes.tolist()}.")
    return class_weights
def _sample_negative_frames(negative_sampling_rate, shape):
    """Randomly select frames, returning a float array of zeros and ones.

    Args:
        negative_sampling_rate: Probability of selecting each entry. The
            values 1.0 and 0.0 are handled without random sampling.
        shape: Shape of the array to create.

    Returns:
        Float array with the given shape, containing 1.0 at the selected
        entries and 0.0 elsewhere.
    """
    if negative_sampling_rate == 1.0:
        return np.ones(shape)
    if negative_sampling_rate == 0.0:
        return np.zeros(shape)
    sampled = np.random.binomial(n=1, p=negative_sampling_rate, size=shape)
    # np.random.binomial produces integers, whereas the two branches above
    # produce floats. Convert, so that the dtype of the result does not
    # depend on the sampling rate.
    return np.asarray(sampled, dtype=np.float64)
6 | 7 | import argparse 8 | import logging 9 | from pathlib import Path 10 | from typing import Dict, List, Tuple 11 | 12 | import numpy as np 13 | 14 | from spivak.application.dataset_creation import read_spotting_label_map 15 | from spivak.data.dataset import Task, INDEX_VALID, INDEX_LABELS, read_num_frames 16 | from spivak.data.dataset_splits import SPLIT_KEY_TEST 17 | from spivak.data.soccernet_label_io import GameSpottingPredictionsReader, \ 18 | SOCCERNET_TYPE_TWO 19 | from spivak.data.soccernet_reader import GameOneHotSpottingLabelReader, \ 20 | GamePathsReader 21 | from spivak.evaluation.aggregate import EvaluationAggregate 22 | from spivak.evaluation.spotting_evaluation import run_spotting_evaluation, \ 23 | SpottingEvaluation, read_tolerances_config 24 | 25 | # Command-line arguments. 26 | ARGS_RESULTS_DIR = "results_dir" 27 | ARGS_CONFIG_DIR = "config_dir" 28 | ARGS_LABELS_DIR = "labels_dir" 29 | ARGS_SPLITS_DIR = "splits_dir" 30 | ARGS_OUTPUT_DIR = "output_dir" 31 | ARGS_FEATURES_DIR = "features_dir" 32 | # Other constants. 
def main() -> None:
    """Evaluate spotting results stored as JSON files and save the outcome.

    Reads the spotting label map and the tolerances from the configuration
    directory, loads the detections and labels, runs the spotting
    evaluation, and finally saves and logs the result.

    Raises:
        ValueError: If the spotting label map could not be read.
    """
    logging.getLogger().setLevel(logging.DEBUG)
    cli_args = _get_command_line_arguments()
    config_path = Path(cli_args[ARGS_CONFIG_DIR])
    label_map = read_spotting_label_map(config_path)
    if not label_map:
        message = (
            f"Could not read spotting label map from configuration dir "
            f"{config_path}.")
        raise ValueError(message)

    logging.info("Going to read detections and labels")
    detections_and_labels = _read_spotting_detections_and_labels(
        cli_args, label_map.num_classes())
    all_detections, all_labels = detections_and_labels
    tolerances = read_tolerances_config(config_path)

    logging.info("Going to run spotting evaluation")
    evaluation_result = run_spotting_evaluation(
        all_detections, all_labels, tolerances, EVALUATION_FRAME_RATE,
        label_map.num_classes(), prune_classes=False,
        create_confusion_data_frame=True, label_map=label_map)
    _save_and_log_spotting_evaluation(
        evaluation_result, Path(cli_args[ARGS_OUTPUT_DIR]))
76 | game_paths_reader = GamePathsReader( 77 | SOCCERNET_TYPE, EVALUATION_FEATURE_NAME, features_dir, labels_dir, 78 | splits_dir) 79 | valid_game_paths = game_paths_reader.read_valid(SPLIT_KEY_TEST) 80 | detections = [] 81 | labels = [] 82 | for game_paths in valid_game_paths: 83 | # Read the number of frames for each video. 84 | num_video_frames_one = read_num_frames(game_paths.features_one) 85 | num_video_frames_two = read_num_frames(game_paths.features_two) 86 | num_label_frames_one = num_video_frames_one 87 | num_label_frames_two = num_video_frames_two 88 | if REPRODUCE_SOCCERNET_EVALUATION: 89 | # In the SoccerNet evaluation code, they don't bother to figure 90 | # out how many frames are needed, and just use a very large 91 | # number. This allows labels that are slightly out of bounds to not 92 | # have to be pushed into bounds, yielding very slightly different 93 | # results. We add a bit of slack here on the size of the label 94 | # matrices in order to match their results. 95 | frame_slack = int(EVALUATION_FRAME_RATE * SLACK_SECONDS) 96 | num_label_frames_one += frame_slack 97 | num_label_frames_two += frame_slack 98 | # Read the labels for the two videos. 99 | labels_and_valid_one, labels_and_valid_two = \ 100 | game_one_hot_label_reader.read( 101 | game_paths.labels.get(Task.SPOTTING), 102 | num_label_frames_one, num_label_frames_two) 103 | assert labels_and_valid_one[INDEX_VALID] 104 | assert labels_and_valid_two[INDEX_VALID] 105 | labels.append(labels_and_valid_one[INDEX_LABELS]) 106 | labels.append(labels_and_valid_two[INDEX_LABELS]) 107 | # Read the detections for the two videos. 
def _spotting_detections_path(detections_dir: Path) -> Path:
    """Return the path of the single JSON file inside detections_dir.

    Args:
        detections_dir: Directory expected to contain exactly one file
            matching "*.json".

    Returns:
        Path to that JSON file.

    Raises:
        FileNotFoundError: If the directory contains no JSON file.
        ValueError: If the directory contains more than one JSON file.
    """
    # Sorting only makes the error message below deterministic.
    json_paths = sorted(detections_dir.glob("*.json"))
    # Raise explicitly instead of using assert: asserts are removed when
    # Python runs with -O, which would turn an empty directory into an
    # IndexError and would silently pick an arbitrary file when there are
    # several candidates.
    if not json_paths:
        raise FileNotFoundError(
            f"No JSON detections file found in {detections_dir}")
    if len(json_paths) > 1:
        names = ", ".join(path.name for path in json_paths)
        raise ValueError(
            f"Expected exactly one JSON detections file in {detections_dir}, "
            f"but found {len(json_paths)}: {names}")
    return json_paths[0]
configuration files") 156 | args_dict = vars(parser.parse_args()) 157 | return args_dict 158 | 159 | 160 | if __name__ == "__main__": 161 | main() 162 | --------------------------------------------------------------------------------