├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── img
│   └── jabs_screenshot.png
├── launch_jabs.bat
├── poetry.lock
├── pyproject.toml
├── ruff.toml
├── setup_windows.bat
├── src
│   ├── __init__.py
│   └── jabs
│       ├── __init__.py
│       ├── __main__.py
│       ├── classifier
│       │   ├── __init__.py
│       │   └── classifier.py
│       ├── cli
│       │   ├── __init__.py
│       │   └── progress_bar.py
│       ├── constants.py
│       ├── feature_extraction
│       │   ├── __init__.py
│       │   ├── angle_index.py
│       │   ├── base_features
│       │   │   ├── __init__.py
│       │   │   ├── angles.py
│       │   │   ├── angular_velocity.py
│       │   │   ├── base_group.py
│       │   │   ├── centroid_velocity.py
│       │   │   ├── pairwise_distances.py
│       │   │   ├── point_speeds.py
│       │   │   └── point_velocities.py
│       │   ├── feature_base_class.py
│       │   ├── feature_group_base_class.py
│       │   ├── features.py
│       │   ├── landmark_features
│       │   │   ├── __init__.py
│       │   │   ├── corner.py
│       │   │   ├── food_hopper.py
│       │   │   ├── landmark_group.py
│       │   │   └── lixit.py
│       │   ├── segmentation_features
│       │   │   ├── __init__.py
│       │   │   ├── hu_moments.py
│       │   │   ├── moment_cache.py
│       │   │   ├── moments.py
│       │   │   ├── segment_group.py
│       │   │   └── shape_descriptors.py
│       │   ├── social_features
│       │   │   ├── __init__.py
│       │   │   ├── closest_distances.py
│       │   │   ├── closest_fov_angles.py
│       │   │   ├── closest_fov_distances.py
│       │   │   ├── pairwise_social_distances.py
│       │   │   ├── social_distance.py
│       │   │   └── social_group.py
│       │   └── window_operations
│       │       ├── __init__.py
│       │       ├── signal_stats.py
│       │       └── window_stats.py
│       ├── pose_estimation
│       │   ├── __init__.py
│       │   ├── pose_est.py
│       │   ├── pose_est_v2.py
│       │   ├── pose_est_v3.py
│       │   ├── pose_est_v4.py
│       │   ├── pose_est_v5.py
│       │   └── pose_est_v6.py
│       ├── project
│       │   ├── __init__.py
│       │   ├── export_training.py
│       │   ├── feature_manager.py
│       │   ├── prediction_manager.py
│       │   ├── project.py
│       │   ├── project_paths.py
│       │   ├── project_utils.py
│       │   ├── read_training.py
│       │   ├── settings_manager.py
│       │   ├── track_labels.py
│       │   ├── video_labels.py
│       │   └── video_manager.py
│       ├── resources
│       │   ├── __init__.py
│       │   ├── docs
│       │   │   ├── features
│       │   │   │   └── features.md
│       │   │   └── user_guide
│       │   │       ├── imgs
│       │   │       │   ├── StackedImgs.svg
│       │   │       │   ├── classifier_controls.png
│       │   │       │   ├── identity_gaps.png
│       │   │       │   ├── label_viz.png
│       │   │       │   ├── main_window.png
│       │   │       │   ├── pose_overlay.png
│       │   │       │   ├── selecting_frames.png
│       │   │       │   ├── stacked_timeline.png
│       │   │       │   ├── timeline_menu.png
│       │   │       │   └── track_overlay.png
│       │   │       └── user_guide.md
│       │   └── icon.png
│       ├── scripts
│       │   ├── __init__.py
│       │   ├── classify.py
│       │   ├── convert_parquet.py
│       │   ├── generate_features.py
│       │   ├── gui_entrypoint.py
│       │   ├── initialize_project.py
│       │   └── stats.py
│       ├── types
│       │   ├── __init__.py
│       │   ├── classifier_types.py
│       │   └── units.py
│       ├── ui
│       │   ├── __init__.py
│       │   ├── about_dialog.py
│       │   ├── archive_behavior_dialog.py
│       │   ├── central_widget.py
│       │   ├── classification_thread.py
│       │   ├── colors.py
│       │   ├── k_fold_slider_widget.py
│       │   ├── label_count_widget.py
│       │   ├── license_dialog.py
│       │   ├── main_control_widget.py
│       │   ├── main_window.py
│       │   ├── player_widget
│       │   │   ├── __init__.py
│       │   │   ├── frame_widget.py
│       │   │   ├── player_thread.py
│       │   │   └── player_widget.py
│       │   ├── project_loader_thread.py
│       │   ├── stacked_timeline_widget
│       │   │   ├── __init__.py
│       │   │   ├── frame_labels_widget.py
│       │   │   ├── label_overview_widget
│       │   │   │   ├── __init__.py
│       │   │   │   ├── label_overview_widget.py
│       │   │   │   ├── manual_label_widget.py
│       │   │   │   ├── predicted_label_widget.py
│       │   │   │   ├── prediction_overview_widget.py
│       │   │   │   ├── timeline_label_widget.py
│       │   │   │   └── timeline_prediction_widget.py
│       │   │   └── stacked_timeline_widget.py
│       │   ├── training_thread.py
│       │   ├── user_guide_viewer_widget
│       │   │   ├── __init__.py
│       │   │   └── user_guide_dialog.py
│       │   └── video_list_widget.py
│       ├── utils
│       │   ├── __init__.py
│       │   ├── sampleposeintervals.py
│       │   └── utilities.py
│       ├── version
│       │   └── __init__.py
│       └── video_reader
│           ├── __init__.py
│           ├── frame_annotation.py
│           ├── utilities.py
│           └── video_reader.py
├── tests
│   ├── __init__.py
│   ├── data
│   │   ├── identity_with_no_data_pose_est_v3.h5.gz
│   │   ├── readme.txt
│   │   ├── sample_pose_est_v2.h5.gz
│   │   ├── sample_pose_est_v3.h5.gz
│   │   ├── sample_pose_est_v4.h5.gz
│   │   ├── sample_pose_est_v5.h5.gz
│   │   └── sample_pose_est_v6.h5.gz
│   ├── feature_modules
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── test_corner_features.py
│   │   ├── test_food_hopper.py
│   │   └── test_lixit_distance.py
│   ├── feature_tests
│   │   ├── __init__.py
│   │   ├── seg_test_utils.py
│   │   ├── test_hu_moments.py
│   │   ├── test_moments.py
│   │   └── test_temporal_features.py.disabled
│   ├── project
│   │   ├── __init__.py
│   │   ├── test_prediction_manager.py
│   │   ├── test_project.py
│   │   ├── test_project_path.py
│   │   ├── test_project_util.py
│   │   ├── test_settings_manager.py
│   │   └── test_video_manager.py
│   ├── test_base_features
│   │   └── test_segmentation.py
│   ├── test_pose_ancillary.py
│   ├── test_pose_file.py
│   ├── test_social_features
│   │   └── test_signal_processing.py
│   ├── test_track_labels.py
│   └── test_video_labels.py
└── vm
    ├── behavior-classifier-vm-gui.def
    ├── behavior-classifier-vm.def
    ├── behavior-classify-batch.sh
    └── generate-features.sh

/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | 
3 | dist
4 | 
5 | .idea
6 | .vscode
7 | 
8 | .coverage
9 | test-reports
10 | 
11 | jabs.venv
12 | venv
13 | .venv
14 | 
15 | .DS_Store
16 | *.sif
17 | 
18 | temp
19 | data
20 | 
21 | tests/testing_notebook.ipynb
22 | docs/notes.txt
23 | docs/work_history.md
--------------------------------------------------------------------------------

/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |   - repo: https://github.com/astral-sh/ruff-pre-commit
3 |     rev: v0.11.10  # Use the latest tag from https://github.com/astral-sh/ruff-pre-commit/tags
4 |     hooks:
5 |       - id: ruff-check
6 |       - id: ruff-format
--------------------------------------------------------------------------------

/img/jabs_screenshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KumarLabJax/JABS-behavior-classifier/4749d378a18e51eed91eea04d0f002472b8abcc2/img/jabs_screenshot.png
--------------------------------------------------------------------------------

/launch_jabs.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | REM Save the current working directory
3 | set "initialDir=%CD%"
4 | 
5 | REM Change to the script's directory
6 | cd /d "%~dp0"
7 | 
8 | 
9 | echo Starting JAX Animal Behavior System App...
10 | call jabs.venv\Scripts\activate.bat && jabs
11 | 
12 | REM restore working directory
13 | cd /d "%initialDir%"
--------------------------------------------------------------------------------

/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "jabs-behavior-classifier"
3 | version = "0.26.0"
4 | license = "Proprietary"
5 | repository = "https://github.com/KumarLabJax/JABS-behavior-classifier"
6 | description = ""
7 | authors = ["Glen Beane", "Brian Geuther", "Keith Sheppard"]
8 | readme = "README.md"
9 | packages = [
10 |     { include = "jabs", from = "src" },
11 | ]
12 | 
13 | [tool.poetry.scripts]
14 | jabs = "jabs.scripts:main"
15 | jabs-classify = "jabs.scripts.classify:main"
16 | jabs-init = "jabs.scripts.initialize_project:main"
17 | jabs-features = "jabs.scripts.generate_features:main"
18 | jabs-stats = "jabs.scripts.stats:main"
19 | jabs-convert-parquet = "jabs.scripts.convert_parquet:main"
20 | 
21 | [tool.poetry.dependencies]
22 | python = ">=3.10,<3.14"
23 | h5py = "^3.10.0"
24 | markdown2 = "^2.5.1"
25 | numpy = "^2.0.0"
26 | opencv-python-headless = "^4.8.1.78"
27 | pandas = "^2.2.2"
28 | pyside6 = "^6.8.0,!=6.9.0"
29 | scikit-learn = "^1.5.0"
30 | shapely = "^2.0.1"
31 | tabulate = "^0.9.0"
32 | toml = "^0.10.2"
33 | xgboost = "1.7.6"
34 | pyarrow = "^20.0.0"
35 | argparse-formatter = "^1.4"
36 | 
37 | [tool.poetry.group.dev.dependencies]
38 | matplotlib = "^3.9.3"
39 | pytest = "^8.3.4"
40 | ruff = "^0.11.5"
41 | pre-commit = "^4.2.0"
42 | 
43 | [build-system]
44 | requires = ["poetry-core"]
45 | build-backend = "poetry.core.masonry.api"
46 | 
--------------------------------------------------------------------------------
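The `[tool.poetry.scripts]` table above is what makes commands like `jabs` and `jabs-classify` available on the PATH after installation: each entry maps a console command to a `module:function` entry point. As a rough sketch (this file is not part of the repository), the wrapper Poetry generates for `jabs-classify` behaves like:

# Hypothetical stand-in for the auto-generated jabs-classify console script.
from jabs.scripts.classify import main

if __name__ == "__main__":
    raise SystemExit(main())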
"VER=" 24 | for /f "usebackq tokens=*" %%i in (`python --version 2^>nul`) do set "VER=%%i" 25 | 26 | if "!VER!"=="" ( 27 | echo Python is not installed or not in PATH. 28 | REM restore working directory 29 | cd /d "%initialDir%" 30 | exit /b 1 31 | ) 32 | 33 | set OK=0 34 | 35 | REM Supported versions of Python 36 | if "!VER:~7,4!"=="3.10" set OK=1 37 | if "!VER:~7,4!"=="3.11" set OK=1 38 | if "!VER:~7,4!"=="3.12" set OK=1 39 | if "!VER:~7,4!"=="3.13" set OK=1 40 | 41 | if "!OK!"=="1" ( 42 | echo Found !VER! 43 | ) else ( 44 | echo Compatible Python not found 45 | REM restore working directory 46 | cd /d "%initialDir%" 47 | exit /b 1 48 | ) 49 | ) else ( 50 | echo Skipping Python version check 51 | ) 52 | 53 | echo Setting up Python Virtualenv... 54 | python -m venv jabs.venv 55 | call jabs.venv\Scripts\activate.bat && pip install . 56 | 57 | REM restore working directory 58 | cd /d "%initialDir%" 59 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | """this exists for pytest to be able to import from src.jabs""" 2 | -------------------------------------------------------------------------------- /src/jabs/__init__.py: -------------------------------------------------------------------------------- 1 | """JABS Behavior Classifier""" 2 | -------------------------------------------------------------------------------- /src/jabs/__main__.py: -------------------------------------------------------------------------------- 1 | from .scripts import main 2 | 3 | if __name__ == "__main__": 4 | main() 5 | -------------------------------------------------------------------------------- /src/jabs/classifier/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The `jabs.classifier` package provides tools for training, evaluating, saving, and loading machine learning classifiers for behavioral data analysis. 3 | 4 | It includes the `Classifier` class, which supports multiple classification algorithms (such as Random Forest, 5 | Gradient Boosting, and XGBoost), utilities for feature management, data splitting, model evaluation, and serialization.` 6 | """ 7 | 8 | import pathlib 9 | 10 | from .classifier import Classifier 11 | 12 | HYPERPARAMETER_PATH = pathlib.Path(__file__).parent / "hyperparameters.json" 13 | 14 | __all__ = [ 15 | "Classifier", 16 | ] 17 | -------------------------------------------------------------------------------- /src/jabs/cli/__init__.py: -------------------------------------------------------------------------------- 1 | """package for CLI utilities""" 2 | 3 | from .progress_bar import cli_progress_bar 4 | 5 | __all__ = [ 6 | "cli_progress_bar", 7 | ] 8 | -------------------------------------------------------------------------------- /src/jabs/cli/progress_bar.py: -------------------------------------------------------------------------------- 1 | def cli_progress_bar( 2 | completed: int, 3 | total_iterations: int, 4 | length=40, 5 | fill_char="█", 6 | padding_char="░", 7 | prefix="", 8 | suffix="", 9 | precision=1, 10 | complete_as_percent=True, 11 | ): 12 | """Call in a loop to create terminal progress bar. 13 | 14 | Note, the loop can't print any other output to stdout. 
15 | 
16 |     Args:
17 |         completed: number of completed iterations
18 |         total_iterations: total number of iterations in calling loop
19 |         length: length of bar in characters
20 |         fill_char: filled bar character
21 |         padding_char: unfilled bar character
22 |         prefix: prefix string, printed before progress bar
23 |         suffix: suffix string, printed after progress bar
24 |         precision: number of decimals in percent complete
25 |         complete_as_percent: if True, print percent complete, if False print "num_complete of total_num"
26 | 
27 |     Todo:
28 |         - replace this with rich progress bar
29 |     """
30 |     if len(fill_char) != 1:
31 |         raise ValueError("Invalid fill character")
32 |     if len(padding_char) != 1:
33 |         raise ValueError("Invalid padding character")
34 |     if precision < 1:
35 |         raise ValueError("Invalid precision parameter")
36 |     if length < 1:
37 |         raise ValueError("Invalid length parameter")
38 | 
39 |     if total_iterations > 0:
40 |         # create a string representation of the percent complete
41 |         complete = f"{100 * (completed / float(total_iterations)):.{precision}f}"
42 | 
43 |         # calculate the length of the filled portion of the progress bar
44 |         filled_length = int(length * completed // total_iterations)
45 |     else:
46 |         complete = 0
47 |         filled_length = 0
48 | 
49 |     # create a string combining filled and unfilled portions
50 |     bar = fill_char * filled_length + padding_char * (length - filled_length)
51 | 
52 |     if complete_as_percent:
53 |         complete = f"{complete}%"
54 |     else:
55 |         width = len(str(total_iterations))
56 |         complete = f"{completed:{width}} of {total_iterations}"
57 | 
58 |     # print progress bar, overwriting the current line
59 |     print(f"\r{prefix}{bar} {complete} {suffix}", end="\r")
60 | 
61 |     # print newline once the progress bar is filled
62 |     if completed == total_iterations:
63 |         print()
64 | 
--------------------------------------------------------------------------------
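Since `cli_progress_bar` redraws the bar in place with a carriage return, it should be the only thing writing to stdout inside the loop. A minimal driving loop might look like this (illustrative sketch; the 100-iteration workload is hypothetical):

import time

from jabs.cli import cli_progress_bar

total = 100
for i in range(total + 1):
    cli_progress_bar(i, total, prefix="processing ", suffix="frames")
    time.sleep(0.01)  # stand-in for real per-iteration work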
10 | """ 11 | 12 | from .features import FEATURE_VERSION, IdentityFeatures 13 | 14 | DEFAULT_WINDOW_SIZE = 5 15 | 16 | __all__ = [ 17 | "DEFAULT_WINDOW_SIZE", 18 | "FEATURE_VERSION", 19 | "IdentityFeatures", 20 | ] 21 | -------------------------------------------------------------------------------- /src/jabs/feature_extraction/angle_index.py: -------------------------------------------------------------------------------- 1 | import enum 2 | 3 | from jabs.pose_estimation import PoseEstimation 4 | 5 | 6 | class AngleIndex(enum.IntEnum): 7 | """enum defining the indexes of the angle features""" 8 | 9 | NOSE_BASE_NECK_RIGHT_FRONT_PAW = 0 10 | NOSE_BASE_NECK_LEFT_FRONT_PAW = 1 11 | RIGHT_FRONT_PAW_BASE_NECK_CENTER_SPINE = 2 12 | LEFT_FRONT_PAW_BASE_NECK_CENTER_SPINE = 3 13 | BASE_NECK_CENTER_SPINE_BASE_TAIL = 4 14 | RIGHT_REAR_PAW_BASE_TAIL_CENTER_SPINE = 5 15 | LEFT_REAR_PAW_BASE_TAIL_CENTER_SPINE = 6 16 | RIGHT_REAR_PAW_BASE_TAIL_MID_TAIL = 7 17 | LEFT_REAR_PAW_BASE_TAIL_MID_TAIL = 8 18 | CENTER_SPINE_BASE_TAIL_MID_TAIL = 9 19 | BASE_TAIL_MID_TAIL_TIP_TAIL = 10 20 | 21 | @staticmethod 22 | def get_angle_name(i: "AngleIndex"): 23 | """map angle index to a string name""" 24 | strings = { 25 | AngleIndex.NOSE_BASE_NECK_RIGHT_FRONT_PAW: "NOSE-BASE_NECK-RIGHT_FRONT_PAW", 26 | AngleIndex.NOSE_BASE_NECK_LEFT_FRONT_PAW: "NOSE-BASE_NECK-LEFT_FRONT_PAW", 27 | AngleIndex.RIGHT_FRONT_PAW_BASE_NECK_CENTER_SPINE: "RIGHT_FRONT_PAW-BASE_NECK-CENTER_SPINE", 28 | AngleIndex.LEFT_FRONT_PAW_BASE_NECK_CENTER_SPINE: "LEFT_FRONT_PAW-BASE_NECK-CENTER_SPINE", 29 | AngleIndex.BASE_NECK_CENTER_SPINE_BASE_TAIL: "BASE_NECK-CENTER_SPINE-BASE_TAIL", 30 | AngleIndex.RIGHT_REAR_PAW_BASE_TAIL_CENTER_SPINE: "RIGHT_REAR_PAW-BASE_TAIL-CENTER_SPINE", 31 | AngleIndex.LEFT_REAR_PAW_BASE_TAIL_CENTER_SPINE: "LEFT_REAR_PAW-BASE_TAIL-CENTER_SPINE", 32 | AngleIndex.RIGHT_REAR_PAW_BASE_TAIL_MID_TAIL: "RIGHT_REAR_PAW-BASE_TAIL-MID_TAIL", 33 | AngleIndex.LEFT_REAR_PAW_BASE_TAIL_MID_TAIL: "LEFT_REAR_PAW-BASE_TAIL-MID_TAIL", 34 | AngleIndex.CENTER_SPINE_BASE_TAIL_MID_TAIL: "CENTER_SPINE-BASE_TAIL-MID_TAIL", 35 | AngleIndex.BASE_TAIL_MID_TAIL_TIP_TAIL: "BASE_TAIL-MID_TAIL-TIP_TAIL", 36 | } 37 | return strings[i] 38 | 39 | @staticmethod 40 | def get_angle_indices(i: "AngleIndex"): 41 | """get the keypoint indices for a given angle index""" 42 | angles = { 43 | AngleIndex.NOSE_BASE_NECK_RIGHT_FRONT_PAW: [ 44 | PoseEstimation.KeypointIndex.NOSE, 45 | PoseEstimation.KeypointIndex.BASE_NECK, 46 | PoseEstimation.KeypointIndex.RIGHT_FRONT_PAW, 47 | ], 48 | AngleIndex.NOSE_BASE_NECK_LEFT_FRONT_PAW: [ 49 | PoseEstimation.KeypointIndex.NOSE, 50 | PoseEstimation.KeypointIndex.BASE_NECK, 51 | PoseEstimation.KeypointIndex.LEFT_FRONT_PAW, 52 | ], 53 | AngleIndex.RIGHT_FRONT_PAW_BASE_NECK_CENTER_SPINE: [ 54 | PoseEstimation.KeypointIndex.RIGHT_FRONT_PAW, 55 | PoseEstimation.KeypointIndex.BASE_NECK, 56 | PoseEstimation.KeypointIndex.CENTER_SPINE, 57 | ], 58 | AngleIndex.LEFT_FRONT_PAW_BASE_NECK_CENTER_SPINE: [ 59 | PoseEstimation.KeypointIndex.LEFT_FRONT_PAW, 60 | PoseEstimation.KeypointIndex.BASE_NECK, 61 | PoseEstimation.KeypointIndex.CENTER_SPINE, 62 | ], 63 | AngleIndex.BASE_NECK_CENTER_SPINE_BASE_TAIL: [ 64 | PoseEstimation.KeypointIndex.BASE_NECK, 65 | PoseEstimation.KeypointIndex.CENTER_SPINE, 66 | PoseEstimation.KeypointIndex.BASE_TAIL, 67 | ], 68 | AngleIndex.RIGHT_REAR_PAW_BASE_TAIL_CENTER_SPINE: [ 69 | PoseEstimation.KeypointIndex.RIGHT_REAR_PAW, 70 | PoseEstimation.KeypointIndex.BASE_TAIL, 71 | PoseEstimation.KeypointIndex.CENTER_SPINE, 72 | ], 
73 |             AngleIndex.LEFT_REAR_PAW_BASE_TAIL_CENTER_SPINE: [
74 |                 PoseEstimation.KeypointIndex.LEFT_REAR_PAW,
75 |                 PoseEstimation.KeypointIndex.BASE_TAIL,
76 |                 PoseEstimation.KeypointIndex.CENTER_SPINE,
77 |             ],
78 |             AngleIndex.RIGHT_REAR_PAW_BASE_TAIL_MID_TAIL: [
79 |                 PoseEstimation.KeypointIndex.RIGHT_REAR_PAW,
80 |                 PoseEstimation.KeypointIndex.BASE_TAIL,
81 |                 PoseEstimation.KeypointIndex.MID_TAIL,
82 |             ],
83 |             AngleIndex.LEFT_REAR_PAW_BASE_TAIL_MID_TAIL: [
84 |                 PoseEstimation.KeypointIndex.LEFT_REAR_PAW,
85 |                 PoseEstimation.KeypointIndex.BASE_TAIL,
86 |                 PoseEstimation.KeypointIndex.MID_TAIL,
87 |             ],
88 |             AngleIndex.CENTER_SPINE_BASE_TAIL_MID_TAIL: [
89 |                 PoseEstimation.KeypointIndex.CENTER_SPINE,
90 |                 PoseEstimation.KeypointIndex.BASE_TAIL,
91 |                 PoseEstimation.KeypointIndex.MID_TAIL,
92 |             ],
93 |             AngleIndex.BASE_TAIL_MID_TAIL_TIP_TAIL: [
94 |                 PoseEstimation.KeypointIndex.BASE_TAIL,
95 |                 PoseEstimation.KeypointIndex.MID_TAIL,
96 |                 PoseEstimation.KeypointIndex.TIP_TAIL,
97 |             ],
98 |         }
99 |         return angles[i]
100 | 
--------------------------------------------------------------------------------
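The two static methods give a display name and the keypoint triple for each angle, with the middle keypoint acting as the vertex. For example (illustrative, not repo code):

from jabs.feature_extraction.angle_index import AngleIndex

idx = AngleIndex.BASE_TAIL_MID_TAIL_TIP_TAIL
print(AngleIndex.get_angle_name(idx))     # BASE_TAIL-MID_TAIL-TIP_TAIL
print(AngleIndex.get_angle_indices(idx))  # [BASE_TAIL, MID_TAIL, TIP_TAIL] keypoints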
15 | """ 16 | 17 | from .angles import Angles 18 | from .angular_velocity import AngularVelocity 19 | from .base_group import BaseFeatureGroup 20 | from .centroid_velocity import CentroidVelocityDir, CentroidVelocityMag 21 | from .pairwise_distances import PairwisePointDistances 22 | from .point_speeds import PointSpeeds 23 | from .point_velocities import PointVelocityDirs 24 | -------------------------------------------------------------------------------- /src/jabs/feature_extraction/base_features/angles.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import numpy as np 4 | import scipy.stats 5 | 6 | from jabs.feature_extraction.angle_index import AngleIndex 7 | from jabs.feature_extraction.feature_base_class import Feature 8 | from jabs.pose_estimation import PoseEstimation 9 | 10 | 11 | class Angles(Feature): 12 | """this module computes joint angles the result is a dict of features of length #frames rows""" 13 | 14 | _name = "angles" 15 | 16 | # override for circular values 17 | _window_operations: typing.ClassVar[dict[str, typing.Callable]] = { 18 | "mean": lambda x: scipy.stats.circmean(x, high=360, nan_policy="omit"), 19 | "std_dev": lambda x: scipy.stats.circstd(x, high=360, nan_policy="omit"), 20 | } 21 | 22 | def __init__(self, poses: PoseEstimation, pixel_scale: float): 23 | super().__init__(poses, pixel_scale) 24 | self._num_angles = len(AngleIndex) 25 | 26 | def per_frame(self, identity: int) -> dict[str, np.ndarray]: 27 | """compute the value of the per frame features for a specific identity""" 28 | values = {} 29 | 30 | poses, _ = self._poses.get_identity_poses(identity, self._pixel_scale) 31 | 32 | for named_angle in AngleIndex: 33 | angle_keypoints = AngleIndex.get_angle_indices(named_angle) 34 | values[f"angle {AngleIndex.get_angle_name(named_angle)}"] = ( 35 | self._compute_angles( 36 | poses[:, angle_keypoints[0]], 37 | poses[:, angle_keypoints[1]], 38 | poses[:, angle_keypoints[2]], 39 | ) 40 | ) 41 | return values 42 | 43 | def window(self, identity: int, window_size: int, per_frame_values: dict) -> dict: 44 | """compute window feature values. 45 | 46 | overrides the base class method to handle circular values. 
/src/jabs/feature_extraction/base_features/angular_velocity.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | from jabs.feature_extraction.feature_base_class import Feature
4 | from jabs.pose_estimation import PoseEstimation
5 | 
6 | 
7 | class AngularVelocity(Feature):
8 |     """compute angular velocity of animal bearing"""
9 | 
10 |     _name = "angular_velocity"
11 | 
12 |     def __init__(self, poses: PoseEstimation, pixel_scale: float):
13 |         super().__init__(poses, pixel_scale)
14 | 
15 |     def per_frame(self, identity: int) -> dict[str, np.ndarray]:
16 |         """compute the value of the per frame features for a specific identity
17 | 
18 |         Args:
19 |             identity: identity to compute features for
20 | 
21 |         Returns:
22 |             dict with feature values
23 |         """
24 |         fps = self._poses.fps
25 | 
26 |         bearings = self._poses.compute_all_bearings(identity)
27 |         velocities = np.full(bearings.shape, np.nan, bearings.dtype)
28 | 
29 |         for i in range(len(bearings) - 1):
30 |             angle1 = bearings[i]
31 |             angle2 = bearings[i + 1]
32 | 
33 |             if np.isnan(angle1) or np.isnan(angle2):
34 |                 continue
35 | 
36 |             angle1 = angle1 % 360
37 |             if angle1 < 0:
38 |                 angle1 += 360
39 | 
40 |             angle2 = angle2 % 360
41 |             if angle2 < 0:
42 |                 angle2 += 360
43 | 
44 |             diff1 = angle2 - angle1
45 |             abs_diff1 = abs(diff1)
46 |             diff2 = (360 + angle2) - angle1
47 |             abs_diff2 = abs(diff2)
48 |             diff3 = angle2 - (360 + angle1)
49 |             abs_diff3 = abs(diff3)
50 | 
51 |             if abs_diff1 <= abs_diff2 and abs_diff1 <= abs_diff3:
52 |                 velocities[i] = diff1
53 |             elif abs_diff2 <= abs_diff3:
54 |                 velocities[i] = diff2
55 |             else:
56 |                 velocities[i] = diff3
57 |         velocities = velocities * fps
58 | 
59 |         return {"angular_velocity": velocities}
--------------------------------------------------------------------------------
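The three `diff` candidates above select the smallest signed rotation between consecutive bearings, so a 350° to 10° step counts as +20° rather than -340°. A compact equivalent, up to the sign chosen at exactly ±180° (sketch, not repo code):

import numpy as np

def smallest_signed_delta(angle1: float, angle2: float) -> float:
    """Signed rotation from angle1 to angle2, wrapped into [-180, 180)."""
    return (angle2 - angle1 + 180) % 360 - 180

print(smallest_signed_delta(350.0, 10.0))  # 20.0, wraps forward through 0
print(smallest_signed_delta(10.0, 350.0))  # -20.0, wraps backward through 0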
/src/jabs/feature_extraction/base_features/base_group.py:
--------------------------------------------------------------------------------
1 | import typing
2 | 
3 | from jabs.feature_extraction.feature_group_base_class import FeatureGroup
4 | 
5 | from ..feature_base_class import Feature
6 | from .angles import Angles
7 | from .angular_velocity import AngularVelocity
8 | from .centroid_velocity import CentroidVelocityDir, CentroidVelocityMag
9 | from .pairwise_distances import PairwisePointDistances
10 | from .point_speeds import PointSpeeds
11 | from .point_velocities import PointVelocityDirs
12 | 
13 | 
14 | class BaseFeatureGroup(FeatureGroup):
15 |     """Feature group for the core "base" features."""
16 | 
17 |     _name = "base"
18 | 
19 |     # build a dictionary that maps a feature name to the class that
20 |     # implements it
21 |     _features: typing.ClassVar[dict[str, Feature]] = {
22 |         PairwisePointDistances.name(): PairwisePointDistances,
23 |         Angles.name(): Angles,
24 |         AngularVelocity.name(): AngularVelocity,
25 |         PointSpeeds.name(): PointSpeeds,
26 |         PointVelocityDirs.name(): PointVelocityDirs,
27 |         CentroidVelocityDir.name(): CentroidVelocityDir,
28 |         CentroidVelocityMag.name(): CentroidVelocityMag,
29 |     }
30 | 
31 |     def _init_feature_mods(self, identity: int):
32 |         """initialize all the feature modules specified in the current config
33 | 
34 |         Args:
35 |             identity: unused, specified by abstract base class
36 | 
37 |         Returns:
38 |             dictionary of initialized feature modules for this group
39 |         """
40 |         return {
41 |             feature: self._features[feature](self._poses, self._pixel_scale)
42 |             for feature in self._enabled_features
43 |         }
--------------------------------------------------------------------------------

/src/jabs/feature_extraction/base_features/centroid_velocity.py:
--------------------------------------------------------------------------------
1 | import typing
2 | 
3 | import numpy as np
4 | import scipy.stats
5 | 
6 | from jabs.feature_extraction.feature_base_class import Feature
7 | from jabs.pose_estimation import PoseEstimation
8 | 
9 | # TODO: merge CentroidVelocityMag and CentroidVelocityDir into a single feature
10 | # with a 2D numpy array of values
11 | # these are currently separate features in the features file, so we keep them
12 | # separate here for ease of implementation, but this results in duplicated
13 | # work computing each feature. Fix at next update to feature h5 file format.
14 | 
15 | 
16 | class CentroidVelocityDir(Feature):
17 |     """feature for the direction of the center of mass velocity"""
18 | 
19 |     _name = "centroid_velocity_dir"
20 | 
21 |     # override for circular values
22 |     _window_operations: typing.ClassVar[dict[str, typing.Callable]] = {
23 |         "mean": lambda x: scipy.stats.circmean(
24 |             x, low=-180, high=180, nan_policy="omit"
25 |         ),
26 |         "std_dev": lambda x: scipy.stats.circstd(
27 |             x, low=-180, high=180, nan_policy="omit"
28 |         ),
29 |     }
30 | 
31 |     def __init__(self, poses: PoseEstimation, pixel_scale: float):
32 |         super().__init__(poses, pixel_scale)
33 | 
34 |     def per_frame(self, identity: int) -> dict[str, np.ndarray]:
35 |         """compute the value of the per frame features for a specific identity"""
36 |         bearings = self._poses.compute_all_bearings(identity)
37 |         frame_valid = self._poses.identity_mask(identity)
38 | 
39 |         # compute the velocity of the center of mass.
40 |         # first, grab convex hulls for this identity
41 |         convex_hulls = self._poses.get_identity_convex_hulls(identity)
42 | 
43 |         # get an array of the indexes of valid frames only
44 |         indexes = np.arange(self._poses.num_frames)[frame_valid == 1]
45 | 
46 |         # get centroids for all frames where this identity is present
47 |         centroid_centers = np.full(
48 |             [self._poses.num_frames, 2], np.nan, dtype=np.float32
49 |         )
50 |         for i in indexes:
51 |             centroid_centers[i, :] = np.asarray(convex_hulls[i].centroid.xy).squeeze()
52 | 
53 |         v = np.gradient(centroid_centers, axis=0)
54 | 
55 |         # compute direction of velocities
56 |         d = np.degrees(np.arctan2(v[:, 1], v[:, 0]))
57 | 
58 |         # subtract animal bearing from orientation
59 |         # convert angle to range -180 to 180
60 |         values = (((d - bearings) + 180) % 360) - 180
61 | 
62 |         return {"centroid_velocity_dir": values}
63 | 
64 |     def window(self, identity: int, window_size: int, per_frame_values: dict) -> dict:
65 |         """compute window feature values for the centroid velocity direction
66 | 
67 |         Overrides base class to use special method for computing window features with circular values
68 | 
69 |         Args:
70 |             identity (int): subject identity
71 |             window_size (int): window size NOTE: (actual window size is 2 *
72 |                 window_size + 1)
73 |             per_frame_values (dict[str, np.ndarray]): dictionary of per frame values for this identity
74 | 
75 |         Returns:
76 |             dict: dictionary where keys are window feature names and values are the computed window features at each
77 |                 frame for the given identity
78 | 
79 |         """
80 |         return self._window_circular(identity, window_size, per_frame_values)
81 | 
82 | 
83 | class CentroidVelocityMag(Feature):
84 |     """feature for the magnitude of the center of mass velocity"""
85 | 
86 |     _name = "centroid_velocity_mag"
87 | 
88 |     def __init__(self, poses: PoseEstimation, pixel_scale: float):
89 |         super().__init__(poses, pixel_scale)
90 | 
91 |     def per_frame(self, identity: int) -> dict[str, np.ndarray]:
92 |         """compute the value of the per frame features for a specific identity
93 | 
94 |         Args:
95 |             identity: identity to compute features for
96 | 
97 |         Returns:
98 |             dict with feature values
99 |         """
100 |         fps = self._poses.fps
101 |         frame_valid = self._poses.identity_mask(identity)
102 | 
103 |         # compute the velocity of the center of mass.
104 |         # first, grab convex hulls for this identity
105 |         convex_hulls = self._poses.get_identity_convex_hulls(identity)
106 | 
107 |         # get an array of the indexes of valid frames only
108 |         indexes = np.arange(self._poses.num_frames)[frame_valid == 1]
109 | 
110 |         # get centroids for all frames where this identity is present
111 |         centroid_centers = np.full(
112 |             [self._poses.num_frames, 2], np.nan, dtype=np.float32
113 |         )
114 |         for i in indexes:
115 |             centroid_centers[i, :] = np.asarray(convex_hulls[i].centroid.xy).squeeze()
116 | 
117 |         # get change over frames
118 |         v = np.gradient(centroid_centers, axis=0)
119 |         values = np.linalg.norm(v, axis=-1) * fps * self._pixel_scale
120 | 
121 |         return {"centroid_velocity_mag": values}
--------------------------------------------------------------------------------
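The circular `_window_operations` overrides matter because a naive mean is wrong at the wrap-around: two headings of -179° and 179° are nearly identical, yet their arithmetic mean is 0°. Illustrative check (not repo code):

import numpy as np
import scipy.stats

angles = np.array([-179.0, 179.0])  # two nearly identical headings
print(np.mean(angles))              # 0.0 -- misleading naive mean
print(scipy.stats.circmean(angles, low=-180, high=180))  # ~±180, the true mean heading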
/src/jabs/feature_extraction/base_features/pairwise_distances.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | from jabs.feature_extraction.feature_base_class import Feature
4 | from jabs.pose_estimation import PoseEstimation
5 | 
6 | 
7 | class PairwisePointDistances(Feature):
8 |     """Feature extraction class for computing pairwise Euclidean distances between all keypoints per frame.
9 | 
10 |     This class calculates the Euclidean distance between every unique pair of keypoints, returning a dictionary
11 |     mapping each keypoint pair to an array of per-frame distances.
12 |     """
13 | 
14 |     _name = "pairwise_distances"
15 | 
16 |     def __init__(self, poses: PoseEstimation, pixel_scale: float):
17 |         super().__init__(poses, pixel_scale)
18 | 
19 |     def per_frame(self, identity: int) -> dict[str, np.ndarray]:
20 |         """compute the value of the per frame features for a specific identity
21 | 
22 |         Args:
23 |             identity: identity to compute features for
24 | 
25 |         Returns:
26 |             dict with feature values
27 |         """
28 |         points, _ = self._poses.get_identity_poses(identity, self._pixel_scale)
29 | 
30 |         values = {}
31 |         point_names = [p.name for p in PoseEstimation.KeypointIndex]
32 | 
33 |         for i in range(0, len(point_names)):
34 |             p1_name = point_names[i]
35 |             for j in range(i + 1, len(point_names)):
36 |                 p2_name = point_names[j]
37 |                 # compute euclidean distance between ith and jth points
38 |                 euclidean_dist = np.sqrt(
39 |                     np.square(points[:, i, 0] - points[:, j, 0])
40 |                     + np.square(points[:, i, 1] - points[:, j, 1])
41 |                 )
42 |                 values[f"{p1_name}-{p2_name}"] = euclidean_dist
43 | 
44 |         return values
--------------------------------------------------------------------------------
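With 12 keypoints the nested loop emits 66 unique pair distances per frame. The same values can be computed in one shot with broadcasting (sketch with hypothetical array sizes, not repo code):

import numpy as np

nframes, npoints = 100, 12
points = np.random.rand(nframes, npoints, 2)  # stand-in for get_identity_poses() output

diffs = points[:, :, None, :] - points[:, None, :, :]  # (nframes, 12, 12, 2)
dists = np.sqrt((diffs**2).sum(axis=-1))               # full distance matrix per frame

i, j = np.triu_indices(npoints, k=1)  # the 66 unique (i, j) pairs
unique_pair_dists = dists[:, i, j]    # shape (nframes, 66)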
13 | """ 14 | 15 | _name = "point_speeds" 16 | 17 | def __init__(self, poses: PoseEstimation, pixel_scale: float): 18 | super().__init__(poses, pixel_scale) 19 | 20 | def per_frame(self, identity: int) -> dict[str, np.ndarray]: 21 | """compute the value of the per frame features for a specific identity 22 | 23 | Args: 24 | identity: identity to compute features for 25 | 26 | Returns: 27 | dict with feature values 28 | """ 29 | fps = self._poses.fps 30 | poses, point_masks = self._poses.get_identity_poses(identity, self._pixel_scale) 31 | 32 | speeds = {} 33 | 34 | # calculate velocities for each point 35 | xy_deltas = np.gradient(poses, axis=0) 36 | point_velocities = np.linalg.norm(xy_deltas, axis=-1) * fps 37 | 38 | for keypoint in PoseEstimation.KeypointIndex: 39 | speeds[f"{keypoint.name} speed"] = point_velocities[:, keypoint.value] 40 | 41 | return speeds 42 | -------------------------------------------------------------------------------- /src/jabs/feature_extraction/base_features/point_velocities.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import typing 3 | 4 | import numpy as np 5 | import scipy.stats 6 | 7 | from jabs.feature_extraction.feature_base_class import Feature 8 | from jabs.pose_estimation import PoseEstimation 9 | 10 | # TODO: merge this with point_speeds to reduce compute 11 | # since they both use keypoint gradients 12 | 13 | 14 | class PointVelocityDirs(Feature, abc.ABC): 15 | """feature for the direction of the point velocity""" 16 | 17 | # subclass must override this 18 | _name = "point_velocity_dirs" 19 | _point_index = None 20 | 21 | # override for circular values 22 | _window_operations: typing.ClassVar[dict[str, typing.Callable]] = { 23 | "mean": lambda x: scipy.stats.circmean( 24 | x, low=-180, high=180, nan_policy="omit" 25 | ), 26 | "std_dev": lambda x: scipy.stats.circstd( 27 | x, low=-180, high=180, nan_policy="omit" 28 | ), 29 | } 30 | 31 | def __init__(self, poses: PoseEstimation, pixel_scale: float): 32 | super().__init__(poses, pixel_scale) 33 | 34 | def per_frame(self, identity: int) -> dict[str, np.ndarray]: 35 | """compute per-frame feature values 36 | 37 | Args: 38 | identity (int): subject identity 39 | 40 | Returns: 41 | dict[str, np.ndarray]: dictionary of per frame values for this identity 42 | """ 43 | poses, point_masks = self._poses.get_identity_poses(identity, self._pixel_scale) 44 | 45 | bearings = self._poses.compute_all_bearings(identity) 46 | 47 | directions = {} 48 | xy_deltas = np.gradient(poses, axis=0) 49 | angles = np.degrees(np.arctan2(xy_deltas[:, :, 1], xy_deltas[:, :, 0])) 50 | 51 | for keypoint in PoseEstimation.KeypointIndex: 52 | directions[f"{keypoint.name} velocity direction"] = ( 53 | (angles[:, keypoint.value] - bearings + 360) % 360 54 | ) - 180 55 | 56 | return directions 57 | 58 | def window(self, identity: int, window_size: int, per_frame_values: dict) -> dict: 59 | """compute window feature values. 
/src/jabs/feature_extraction/feature_group_base_class.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import typing
3 | 
4 | from jabs.pose_estimation import PoseEstimation
5 | 
6 | from .feature_base_class import Feature
7 | 
8 | 
9 | class FeatureGroup(abc.ABC):
10 |     """Abstract base class for groups of related feature extraction modules.
11 | 
12 |     This class manages a collection of feature modules, providing methods to compute per-frame and windowed features
13 |     for a given subject identity. It also handles enabling/disabling features, querying supported features based on
14 |     pose version and static objects, and provides class-level metadata.
15 | 
16 |     Methods:
17 |         per_frame(identity): Compute per-frame features for a specific identity.
18 |         window(identity, window_size, per_frame_values): Compute windowed features for a specific identity.
19 |         enabled_features: Property returning the names of currently enabled features.
20 |         module_names(): Class method returning all feature names in this group.
21 |         name(): Class method returning the group name.
22 |         get_supported_feature_modules(pose_version, static_objects, **kwargs): Class method returning supported features.
23 |         _init_feature_mods(identity): Abstract method to initialize feature modules for an identity.
24 |     """
25 | 
26 |     # to be defined in subclass
27 |     _features: typing.ClassVar[dict[str, Feature]] = {}
28 |     _name = None
29 | 
30 |     def __init__(self, poses: PoseEstimation, pixel_scale: float):
31 |         super().__init__()
32 |         self._enabled_features = []
33 |         self._poses = poses
34 |         self._pixel_scale = pixel_scale
35 |         if self._name is None:
36 |             raise NotImplementedError("Base class must override _name class member")
37 | 
38 |         # _features above defines all features that are part of this group,
39 |         # but self._enabled_features lists which features are currently enabled.
40 |         # by default, all features are turned on
41 |         self._enabled_features = list(self._features.keys())
42 | 
43 |     def per_frame(self, identity: int) -> dict:
44 |         """compute the value of the per frame features for a specific identity
45 | 
46 |         Args:
47 |             identity: identity to compute features for
48 | 
49 |         Returns:
50 |             dict where each key is the name of a feature module included in this FeatureGroup
51 |         """
52 |         feature_modules = self._init_feature_mods(identity)
53 |         return {name: mod.per_frame(identity) for name, mod in feature_modules.items()}
54 | 
55 |     def window(self, identity: int, window_size: int, per_frame_values: dict) -> dict:
56 |         """compute window feature values for a given identity's per frame values
57 | 
58 |         Args:
59 |             identity: subject identity
60 |             window_size: window size NOTE: (actual window size is 2 *
61 |                 window_size + 1)
62 |             per_frame_values: per frame feature values
63 | 
64 |         Returns:
65 |             dictionary where keys are feature module names that are part
66 |             of this FeatureGroup. The value for each element is the window feature
67 |             dict returned by that module.
68 |         """
69 |         feature_modules = self._init_feature_mods(identity)
70 |         return {
71 |             name: mod.window(identity, window_size, per_frame_values[name])
72 |             for name, mod in feature_modules.items()
73 |         }
74 | 
75 |     @property
76 |     def enabled_features(self):
77 |         """return the names of the features that are currently enabled in this group"""
78 |         return self._enabled_features
79 | 
80 |     @abc.abstractmethod
81 |     def _init_feature_mods(self, identity: int) -> dict:
82 |         pass
83 | 
84 |     @classmethod
85 |     def module_names(cls):
86 |         """return the names of the features in this group"""
87 |         return list(cls._features.keys())
88 | 
89 |     @classmethod
90 |     def name(cls):
91 |         """return the name of this feature group"""
92 |         return cls._name
93 | 
94 |     @classmethod
95 |     def get_supported_feature_modules(
96 |         cls,
97 |         pose_version: int,
98 |         static_objects: set[str],
99 |         **kwargs,
100 |     ) -> list[str]:
101 |         """Get the features supported by this group based on the pose version, static objects, and optional additional attributes
102 | 
103 |         Args:
104 |             pose_version (int): version of the pose estimation file
105 |             static_objects (set[str]): set of static objects available to the project
106 |             **kwargs: additional keyword arguments that may be used by specific feature classes
107 |         """
108 |         features = []
109 |         for feature_name, feature_class in cls._features.items():
110 |             if feature_class.is_supported(pose_version, static_objects, **kwargs):
111 |                 features.append(feature_name)
112 |         return features
--------------------------------------------------------------------------------

/src/jabs/feature_extraction/landmark_features/__init__.py:
--------------------------------------------------------------------------------
1 | """package for feature extraction based on static objects located in the enclosure."""
2 | 
3 | from .landmark_group import LandmarkFeatureGroup
--------------------------------------------------------------------------------

/src/jabs/feature_extraction/landmark_features/food_hopper.py:
--------------------------------------------------------------------------------
1 | import typing
2 | 
3 | import cv2
4 | import numpy as np
5 | 
6 | from jabs.feature_extraction.feature_base_class import Feature
7 | from jabs.pose_estimation import PoseEstimation
8 | 
9 | _EXCLUDED_POINTS = [
10 |     PoseEstimation.KeypointIndex.MID_TAIL,
11 |     PoseEstimation.KeypointIndex.TIP_TAIL,
12 | ]
13 | 
14 | 
15 | class FoodHopper(Feature):
16 |     """Feature extraction class for computing distances from mouse keypoints to the food hopper polygon.
17 | 
18 |     For each frame and identity, this class calculates the signed distance from each keypoint (excluding mid tail
19 |     and tip tail) to the polygon defined by the food hopper's keypoints. The distance is positive if the keypoint is
20 |     inside the polygon, negative if outside, and zero if on the edge. Results are provided as per-frame arrays for
21 |     each relevant keypoint.
22 |     """
23 | 
24 |     _name = "food_hopper"
25 |     _min_pose = 5
26 |     _static_objects: typing.ClassVar[list[str]] = ["food_hopper"]
27 | 
28 |     def per_frame(self, identity: int) -> dict:
29 |         """get the per frame feature values for the food hopper landmark
30 | 
31 |         Args:
32 |             identity: identity to get feature values for
33 | 
34 |         Returns:
35 |             dict of 10 per-frame arrays of values, one per key point
36 | 
37 |             for each frame, each of the 10 values is the signed distance from the key point
38 |             to the polygon defined by the food hopper key points (10 arrays
39 |             because the mid tail and tail tip are excluded)
40 |         """
41 |         hopper = self._poses.static_objects["food_hopper"]
42 |         if self._pixel_scale is not None:
43 |             hopper = hopper * self._pixel_scale
44 | 
45 |         # change dtype to float32 for open cv
46 |         hopper_pts = hopper.astype(np.float32)
47 | 
48 |         points, _ = self._poses.get_identity_poses(identity, self._pixel_scale)
49 | 
50 |         values = {}
51 | 
52 |         # for each keypoint (except mid tail and tail tip), compute signed distance (measureDist=True)
53 |         # to the polygon defined by the food hopper key points. Distance is negative if the key point
54 |         # is outside the polygon, positive if inside, zero if it is on the edge
55 |         for key_point in PoseEstimation.KeypointIndex:
56 |             # skip over the key points we don't care about
57 |             if key_point in _EXCLUDED_POINTS:
58 |                 continue
59 | 
60 |             pts = points[:, key_point.value, :]
61 | 
62 |             distance = np.asarray(
63 |                 [cv2.pointPolygonTest(hopper_pts, (p[0], p[1]), True) for p in pts]
64 |             )
65 |             distance[np.isnan(pts[:, 0])] = np.nan
66 |             values[f"food hopper {key_point.name}"] = distance
67 | 
68 |         return values
--------------------------------------------------------------------------------
14 | """ 15 | 16 | from .hu_moments import HuMoments 17 | from .moment_cache import MomentInfo 18 | from .moments import Moments 19 | from .segment_group import SegmentationFeatureGroup 20 | from .shape_descriptors import ShapeDescriptors 21 | -------------------------------------------------------------------------------- /src/jabs/feature_extraction/segmentation_features/hu_moments.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | from jabs.feature_extraction.feature_base_class import Feature 7 | from jabs.pose_estimation import PoseEstimation 8 | 9 | if typing.TYPE_CHECKING: 10 | from .moment_cache import MomentInfo 11 | 12 | 13 | class HuMoments(Feature): 14 | """Feature for the hu image moments of the segmentation contours.""" 15 | 16 | _name = "hu_moments" 17 | _feature_names: typing.ClassVar[list[str]] = [f"hu{i}" for i in range(1, 8)] 18 | 19 | def __init__( 20 | self, poses: PoseEstimation, pixel_scale: float, moment_cache: "MomentInfo" 21 | ): 22 | super().__init__(poses, pixel_scale) 23 | self._moment_cache = moment_cache 24 | 25 | def per_frame(self, identity: int) -> dict[str, np.ndarray]: 26 | """Computes per-frame Hu image moment features for a specific identity. 27 | 28 | For each frame, calculates the seven Hu invariant moments from the cached image moments 29 | of the segmentation contours. 30 | 31 | Args: 32 | identity (int): The identity index for which to compute Hu moments. 33 | 34 | Returns: 35 | dict[str, np.ndarray]: Dictionary mapping Hu moment names ("hu1" to "hu7") to per-frame arrays of values. 36 | """ 37 | values = { 38 | name: np.zeros([self._poses.num_frames], dtype=np.float32) 39 | for name in self._feature_names 40 | } 41 | 42 | for frame in range(self._poses.num_frames): 43 | # Skip calculation if m00 is 0 44 | if self._moment_cache.get_moment(frame, "m00") == 0: 45 | continue 46 | hu_moments = cv2.HuMoments(self._moment_cache.get_all_moments(frame)) 47 | for i, name in enumerate(self._feature_names): 48 | values[name][frame] = hu_moments[i, 0] 49 | 50 | return values 51 | -------------------------------------------------------------------------------- /src/jabs/feature_extraction/segmentation_features/moment_cache.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | from jabs.pose_estimation import PoseEstimationV6 5 | 6 | 7 | class MomentInfo: 8 | """this info is needed to compute a number of different image moment features. 9 | 10 | It can be done once for a given identity, and then an instance of this object can be passed into all the 11 | features that need it. Image moments provided here are adjusted for pixel scaling. 12 | 13 | get_moment(frame, key) retrieves the calculated image moment 14 | 15 | Args: 16 | poses (PoseEstimationV6): V6+ Pose estimation data for one video. 17 | identity (int): Identity to compute moments for. 18 | pixel_scale (float): Scale factor to convert pixel distances to cm. 
19 | """ 20 | 21 | def __init__(self, poses: PoseEstimationV6, identity: int, pixel_scale: float): 22 | # These keys are necessary because opencv returns a dict which may not always be sorted 23 | self._moment_keys = list(cv2.moments(np.empty(0))) 24 | self._moment_conversion_powers = [ 25 | self.get_pixel_power(feature_name) for feature_name in self._moment_keys 26 | ] 27 | self._poses = poses 28 | self._pixel_scale = pixel_scale 29 | self._moments = np.zeros( 30 | (self._poses.num_frames, len(self._moment_keys)), dtype=np.float32 31 | ) 32 | self._seg_data = self._poses.get_segmentation_data(identity) 33 | self._seg_flags = self._poses.get_segmentation_flags(identity) 34 | 35 | # Parse out the contour matrix into a list of contour lists 36 | tmp_contour_data = [] 37 | for frame in range(self._moments.shape[0]): 38 | tmp_contour_data.append(self.trim_contour_list(self._seg_data[frame, ...])) 39 | 40 | self._seg_data = tmp_contour_data 41 | 42 | for frame, contours in enumerate(self._seg_data): 43 | # No segmentation data was present, skip calculating moments 44 | if len(contours) < 1: 45 | moments = dict.fromkeys(self._moment_keys, np.nan) 46 | else: 47 | moments = self.calculate_moments(contours) 48 | # Update the output array with the desired moments for each frame. 49 | for j in range(len(self._moment_keys)): 50 | self._moments[frame, j] = moments[self._moment_keys[j]] * np.power( 51 | self._pixel_scale, self._moment_conversion_powers[j] 52 | ) 53 | 54 | def get_pixel_power(self, key): 55 | """get the degree that pixels influence this image moment 56 | 57 | Args: 58 | key: key of the image moment 59 | 60 | Returns: 61 | power that should be used for converting from pixels to cm 62 | space 63 | """ 64 | # Only works for image moments 0-9 on either dimension 65 | # opencv only does the first 3 moments (0-2) 66 | return int(key[-1]) + int(key[-2]) + 2 67 | 68 | def get_moment(self, frame, key): 69 | """retrieve a single moment value 70 | 71 | Args: 72 | frame: frame to retrieve moment data 73 | key: key of moment data to retrieve 74 | 75 | Returns: 76 | moment value 77 | """ 78 | key_idx = self._moment_keys.index(key) 79 | return self._moments[frame, key_idx] 80 | 81 | def get_all_moments(self, frame): 82 | """retrieve moments for a frame 83 | 84 | Args: 85 | frame: frame to retrieve moment data 86 | 87 | Returns: 88 | dict of moment data 89 | """ 90 | return dict(zip(self._moment_keys, self._moments[frame], strict=False)) 91 | 92 | def get_trimmed_contours(self, frame): 93 | """retrieves a contour for a specific frame 94 | 95 | Args: 96 | frame: frame to retrieve contour data 97 | 98 | Returns: 99 | an opencv-complaint list of contours 100 | """ 101 | return self._seg_data[frame] 102 | 103 | def get_flags(self, frame): 104 | """retrieves the internal/external flags for a specific frame 105 | 106 | Args: 107 | frame: frame to retrieve flags 108 | 109 | Returns: 110 | a binary vector of whether the segmentation contours are 111 | external (1) or internal (0) 112 | """ 113 | return self._seg_flags[frame] 114 | 115 | def trim_contour(self, arr): 116 | """removes -1s from contour data 117 | 118 | Args: 119 | arr: contour, padded with -1s 120 | 121 | Returns: 122 | opencv-complaint contour 123 | """ 124 | assert arr.ndim == 2 125 | return_arr = arr[np.all(arr != -1, axis=1), :] 126 | if len(return_arr) > 0: 127 | return return_arr.astype(np.int32) 128 | 129 | def trim_contour_list(self, arr): 130 | """trims a fully padded 3D matrix into opencv-compliant contour list 131 | 132 | Args: 133 | 
arr: a full matrix of contours, padded with -1s 134 | 135 | Returns: 136 | opencv-complaint contour list 137 | """ 138 | assert arr.ndim == 3 139 | return [self.trim_contour(x) for x in arr if np.any(x != -1)] 140 | 141 | @staticmethod 142 | def calculate_moments(contour_list): 143 | """Renders the contour data onto a frame to calculate the moments 144 | 145 | Args: 146 | contour_list: list of polygons, the format opencv returns 147 | from cv2.findContours 148 | 149 | Returns: 150 | dict of cv2.moments image moments 151 | """ 152 | frame_size = np.max(np.concatenate(contour_list)) + 1 153 | # Render the contours on a frame 154 | render = np.zeros([frame_size, frame_size, 1], dtype=np.uint8) 155 | _ = cv2.drawContours(render, contour_list, -1, [1], -1) 156 | return cv2.moments(render) 157 | -------------------------------------------------------------------------------- /src/jabs/feature_extraction/segmentation_features/moments.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import numpy as np 4 | 5 | from jabs.feature_extraction.feature_base_class import Feature 6 | from jabs.pose_estimation import PoseEstimation 7 | 8 | if typing.TYPE_CHECKING: 9 | from .moment_cache import MomentInfo 10 | 11 | 12 | class Moments(Feature): 13 | """feature for the image moments of the contours.""" 14 | 15 | _name = "moments" 16 | # These are all the opencv moments 17 | # _feature_names = ['m00', 'm10', 'm01', 'm20', 'm11', 'm02', 'm30', 'm21', 'm12', 'm03', 'mu20', 'mu11', 'mu02', 'mu30', 'mu21', 'mu12', 'mu03', 'nu20', 'nu11', 'nu02', 'nu30', 'nu21', 'nu12', 'nu03'] 18 | # However, we only want to look at egocentric (translational invariant) 19 | # mu (central or relative to centroid) moments and nu (normalized central) moments meet this translational invariance criteria 20 | # nu moments are also scale-invariant 21 | _moments_to_use: typing.ClassVar[list[str]] = [ 22 | "m00", 23 | "mu20", 24 | "mu11", 25 | "mu02", 26 | "mu30", 27 | "mu21", 28 | "mu12", 29 | "mu03", 30 | "nu20", 31 | "nu11", 32 | "nu02", 33 | "nu30", 34 | "nu21", 35 | "nu12", 36 | "nu03", 37 | ] 38 | 39 | def __init__( 40 | self, poses: PoseEstimation, pixel_scale: float, moment_cache: "MomentInfo" 41 | ): 42 | super().__init__(poses, pixel_scale) 43 | self._moment_cache = moment_cache 44 | 45 | def per_frame(self, identity: int) -> dict[str, np.ndarray]: 46 | """Computes per-frame image moment features for a specific identity. 47 | 48 | For each frame, extracts selected translational and scale-invariant image moments 49 | (central and normalized central moments) from the cached segmentation data. 50 | 51 | Args: 52 | identity (int): The identity index for which to compute moment features. 53 | 54 | Returns: 55 | dict[str, np.ndarray]: Dictionary mapping moment names to per-frame arrays of values. 
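`get_pixel_power` converts raw pixel-space moments to cm-space: a moment whose key ends in digits p and q mixes x^p y^q with the pixel area element, so it scales as pixel_scale^(p+q+2). Worked example with hypothetical numbers (not repo code):

pixel_scale = 0.08  # hypothetical cm-per-pixel calibration

def pixel_power(key: str) -> int:
    # same rule as MomentInfo.get_pixel_power
    return int(key[-1]) + int(key[-2]) + 2

area_px = 1500.0  # hypothetical m00 (area) in pixel units
print(area_px * pixel_scale ** pixel_power("m00"))  # 9.6 -> cm^2, since the power is 2
print(pixel_power("mu30"))                          # 5   -> mu30 scales as px^5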
56 | """ 57 | values = {} 58 | 59 | for cur_moment in self._moments_to_use: 60 | vector = np.zeros([self._poses.num_frames], dtype=np.float32) 61 | for frame in range(self._poses.num_frames): 62 | vector[frame] = self._moment_cache.get_moment(frame, cur_moment) 63 | values[cur_moment] = vector 64 | 65 | return values 66 | -------------------------------------------------------------------------------- /src/jabs/feature_extraction/segmentation_features/segment_group.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | from jabs.feature_extraction.feature_group_base_class import FeatureGroup 4 | from jabs.pose_estimation import PoseEstimation 5 | 6 | from ..feature_base_class import Feature 7 | 8 | # import all feature modules for this group 9 | from .hu_moments import HuMoments 10 | from .moment_cache import MomentInfo 11 | from .moments import Moments 12 | from .shape_descriptors import ShapeDescriptors 13 | 14 | 15 | class SegmentationFeatureGroup(FeatureGroup): 16 | """A feature group for extracting segmentation features from pose estimation data.""" 17 | 18 | _name = "segmentation" 19 | 20 | # build dictionary mapping feature name to class that implements it 21 | _features: typing.ClassVar[dict[str, Feature]] = { 22 | Moments.name(): Moments, 23 | ShapeDescriptors.name(): ShapeDescriptors, 24 | HuMoments.name(): HuMoments, 25 | } 26 | 27 | def __init__(self, poses: PoseEstimation, pixel_scale: float): 28 | super().__init__(poses, pixel_scale) 29 | self._moments_cache = None 30 | 31 | def _init_feature_mods(self, identity: int): 32 | """initialize all of the feature modules specified in the current config 33 | 34 | Args: 35 | identity: subject identity to use when computing segmentation features 36 | 37 | Returns: 38 | dictionary of initialized feature modules for this group 39 | """ 40 | self._moments_cache = MomentInfo(self._poses, identity, self._pixel_scale) 41 | 42 | return { 43 | feature: self._features[feature]( 44 | self._poses, self._pixel_scale, self._moments_cache 45 | ) 46 | for feature in self._enabled_features 47 | } 48 | -------------------------------------------------------------------------------- /src/jabs/feature_extraction/social_features/__init__.py: -------------------------------------------------------------------------------- 1 | """Social feature extraction package. 2 | 3 | This package provides classes and utilities to compute social interaction features from pose estimation data, including 4 | pairwise distances, closest distances, and field-of-view (FoV) based metrics. These features are useful for analyzing 5 | proximity and orientation-based social behaviors in multi-subject tracking scenarios. 6 | 7 | Modules: 8 | - closest_distances: Closest distance between subject and nearest other animal. 9 | - closest_fov_angles: Angle to the closest animal within subject's FoV. 10 | - closest_fov_distances: Closest distance to animal within FoV. 11 | - pairwise_social_distances: Pairwise distances for keypoint subsets. 
12 | """ 13 | 14 | from .closest_distances import ClosestDistances 15 | from .closest_fov_angles import ClosestFovAngles 16 | from .closest_fov_distances import ClosestFovDistances 17 | from .pairwise_social_distances import ( 18 | PairwiseSocialDistances, 19 | PairwiseSocialFovDistances, 20 | ) 21 | from .social_group import SocialFeatureGroup 22 | -------------------------------------------------------------------------------- /src/jabs/feature_extraction/social_features/closest_distances.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import numpy as np 4 | 5 | from jabs.feature_extraction.feature_base_class import Feature 6 | 7 | if typing.TYPE_CHECKING: 8 | from jabs.pose_estimation import PoseEstimation 9 | 10 | from .social_distance import ClosestIdentityInfo 11 | 12 | 13 | class ClosestDistances(Feature): 14 | """ 15 | Computes the distance between a subject and the nearest other identity for each frame. 16 | 17 | This feature calculates, for each frame, the distance between the subject and the closest other identity, 18 | based on pose estimation data. The result is useful for analyzing proximity-based social interactions. 19 | 20 | Args: 21 | poses (PoseEstimation): Pose estimation data for a video. 22 | pixel_scale (float): Scale factor to convert pixel distances to real-world units (cm). 23 | social_distance_info (ClosestIdentityInfo): Object providing pre-computed closest identity information. 24 | """ 25 | 26 | _name = "closest_distances" 27 | _min_pose = 3 28 | 29 | def __init__( 30 | self, 31 | poses: "PoseEstimation", 32 | pixel_scale: float, 33 | social_distance_info: "ClosestIdentityInfo", 34 | ): 35 | super().__init__(poses, pixel_scale) 36 | self._social_distance_info = social_distance_info 37 | 38 | def per_frame(self, identity: int) -> dict[str, np.ndarray]: 39 | """compute the value of the per frame features for a specific identity 40 | 41 | Args: 42 | identity: identity to compute features for 43 | 44 | Returns: 45 | dict with feature values 46 | """ 47 | return { 48 | "closest social distance": self._social_distance_info.compute_distances( 49 | self._social_distance_info.closest_identities 50 | ) 51 | } 52 | -------------------------------------------------------------------------------- /src/jabs/feature_extraction/social_features/closest_fov_angles.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import numpy as np 4 | import scipy.stats 5 | 6 | from jabs.feature_extraction.feature_base_class import Feature 7 | 8 | if typing.TYPE_CHECKING: 9 | from jabs.pose_estimation import PoseEstimation 10 | 11 | from .social_distance import ClosestIdentityInfo 12 | 13 | 14 | class ClosestFovAngles(Feature): 15 | """ 16 | Computes the angle between a subject and the closest other identity within its field of view (FoV) for each frame. 17 | 18 | This feature provides, for each frame, the angle (in degrees) from the subject to the closest other identity 19 | that is within its field of view, based on pose estimation data. The angles are treated as circular values for 20 | windowed operations. 21 | 22 | Args: 23 | poses (PoseEstimation): Pose estimation data for a video. 24 | pixel_scale (float): Scale factor to convert pixel distances to cm. 25 | social_distance_info (ClosestIdentityInfo): Object providing closest identity and FoV angle information. 
26 | """ 27 | 28 | _name = "closest_fov_angles" 29 | _min_pose = 3 30 | 31 | # override for circular values 32 | _window_operations: typing.ClassVar[dict[str, typing.Callable]] = { 33 | "mean": lambda x: scipy.stats.circmean( 34 | x, low=-180, high=180, nan_policy="omit" 35 | ), 36 | "std_dev": lambda x: scipy.stats.circstd( 37 | x, low=-180, high=180, nan_policy="omit" 38 | ), 39 | } 40 | 41 | def __init__( 42 | self, 43 | poses: "PoseEstimation", 44 | pixel_scale: float, 45 | social_distance_info: "ClosestIdentityInfo", 46 | ): 47 | super().__init__(poses, pixel_scale) 48 | self._social_distance_info = social_distance_info 49 | 50 | def per_frame(self, identity: int) -> dict[str, np.ndarray]: 51 | """compute the value of the per frame features for a specific identity 52 | 53 | Args: 54 | identity: identity to compute features for 55 | 56 | Returns: 57 | dict with feature values 58 | """ 59 | # this is already computed 60 | return { 61 | "angle of closest social distance in FoV": self._social_distance_info.closest_fov_angles 62 | } 63 | 64 | def window( 65 | self, identity: int, window_size: int, per_frame_values: dict[str, np.ndarray] 66 | ) -> dict: 67 | """compute window feature values for a given identities per frame values""" 68 | # need to override to use special method for computing window features with circular values 69 | return self._window_circular(identity, window_size, per_frame_values) 70 | -------------------------------------------------------------------------------- /src/jabs/feature_extraction/social_features/closest_fov_distances.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import numpy as np 4 | 5 | from jabs.feature_extraction.feature_base_class import Feature 6 | 7 | if typing.TYPE_CHECKING: 8 | from jabs.pose_estimation import PoseEstimation 9 | 10 | from .social_distance import ClosestIdentityInfo 11 | 12 | 13 | class ClosestFovDistances(Feature): 14 | """Computes the closest distance between a subject and the nearest other identity within its field of view (FoV). 15 | 16 | Args: 17 | poses (PoseEstimation): Pose estimation data for one video. 18 | pixel_scale (float): Scale factor to convert pixel distances to cm. 19 | social_distance_info (ClosestIdentityInfo): Object providing closest identity and FoV information. 
20 | """ 21 | 22 | _name = "closest_fov_distances" 23 | _min_pose = 3 24 | 25 | def __init__( 26 | self, 27 | poses: "PoseEstimation", 28 | pixel_scale: float, 29 | social_distance_info: "ClosestIdentityInfo", 30 | ): 31 | super().__init__(poses, pixel_scale) 32 | self._social_distance_info = social_distance_info 33 | 34 | def per_frame(self, identity: int) -> dict[str, np.ndarray]: 35 | """compute the value of the per frame features for a specific identity 36 | 37 | Args: 38 | identity: identity to compute features for 39 | 40 | Returns: 41 | dict with feature values 42 | """ 43 | return { 44 | "closest social distance in FoV": self._social_distance_info.compute_distances( 45 | self._social_distance_info.closest_fov_identities 46 | ) 47 | } 48 | -------------------------------------------------------------------------------- /src/jabs/feature_extraction/social_features/pairwise_social_distances.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import numpy as np 4 | 5 | from jabs.feature_extraction.feature_base_class import Feature 6 | from jabs.pose_estimation import PoseEstimation 7 | 8 | if typing.TYPE_CHECKING: 9 | from .social_distance import ClosestIdentityInfo 10 | 11 | # For social interaction we will consider a subset 12 | # of points to capture just the most important 13 | # information for social. 14 | _social_point_subset = [ 15 | PoseEstimation.KeypointIndex.NOSE, 16 | PoseEstimation.KeypointIndex.BASE_NECK, 17 | PoseEstimation.KeypointIndex.BASE_TAIL, 18 | ] 19 | 20 | 21 | class PairwiseSocialDistances(Feature): 22 | """Computes pairwise social distances between a subject and its closest other identity for a subset of keypoints. 23 | 24 | This feature extracts, for each frame, the distances between all pairs of keypoints in a predefined subset 25 | for the subject and the closest other identity. The distances are used to characterize social interactions 26 | based on pose estimation data. 27 | 28 | Args: 29 | poses (PoseEstimation): Pose estimation data for all subjects. 30 | pixel_scale (float): Scale factor to convert pixel distances to real-world units. 31 | social_distance_info (ClosestIdentityInfo): Object providing closest identity information. 32 | 33 | Methods: 34 | per_frame(identity): Computes per-frame pairwise social distances for a given identity. 
35 | """ 36 | 37 | _name = "social_pairwise_distances" 38 | _min_pose = 3 39 | 40 | # total number of values created by pairwise distances between the 41 | # subject and closest other identity for this subset of points 42 | _num_social_distances = len(_social_point_subset) ** 2 43 | 44 | def __init__( 45 | self, 46 | poses: PoseEstimation, 47 | pixel_scale: float, 48 | social_distance_info: "ClosestIdentityInfo", 49 | ): 50 | super().__init__(poses, pixel_scale) 51 | self._social_distance_info = social_distance_info 52 | self._poses = poses 53 | 54 | def per_frame(self, identity: int) -> dict: 55 | """compute the value of the per frame features for a specific identity 56 | 57 | Args: 58 | identity: identity to compute features for 59 | 60 | Returns: 61 | dict with feature values 62 | """ 63 | return self._social_distance_info.compute_pairwise_social_distances( 64 | _social_point_subset, self._social_distance_info.closest_identities 65 | ) 66 | 67 | 68 | class PairwiseSocialFovDistances(PairwiseSocialDistances): 69 | """compute pairwise social distances between subject and closest other animal in field of view 70 | 71 | nearly the same as the PairwiseSocialDistances, except closest_fov_identities is passed to 72 | compute_pairwise_social_distances rather than closest_identities 73 | """ 74 | 75 | _name = "social_pairwise_fov_distances" 76 | 77 | def per_frame(self, identity: int) -> dict[str, np.ndarray]: 78 | """compute the value of the per frame features for a specific identity 79 | 80 | Args: 81 | identity: identity to compute features for 82 | 83 | Returns: 84 | np.ndarray with feature values 85 | """ 86 | return self._social_distance_info.compute_pairwise_social_distances( 87 | _social_point_subset, self._social_distance_info.closest_fov_identities 88 | ) 89 | -------------------------------------------------------------------------------- /src/jabs/feature_extraction/social_features/social_group.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | from jabs.feature_extraction.feature_group_base_class import FeatureGroup 4 | from jabs.pose_estimation import PoseEstimation 5 | 6 | from ..feature_base_class import Feature 7 | from .closest_distances import ClosestDistances 8 | from .closest_fov_angles import ClosestFovAngles 9 | from .closest_fov_distances import ClosestFovDistances 10 | from .pairwise_social_distances import ( 11 | PairwiseSocialDistances, 12 | PairwiseSocialFovDistances, 13 | ) 14 | from .social_distance import ClosestIdentityInfo 15 | 16 | 17 | class SocialFeatureGroup(FeatureGroup): 18 | """A feature group for extracting social interaction features from pose estimation data. 19 | 20 | This class manages the computation and caching of various social features, such as closest distances, 21 | field-of-view angles, and pairwise social distances, for a given subject identity. It initializes and 22 | provides access to feature modules relevant to social behavior analysis. 23 | 24 | Args: 25 | poses (PoseEstimation): Pose estimation data for a video. 26 | pixel_scale (float): Scale factor to convert pixel distances to real-world units (cm). 
27 | """ 28 | 29 | _name = "social" 30 | 31 | # build dictionary mapping feature name to class that implements it 32 | _features: typing.ClassVar[dict[str, Feature]] = { 33 | ClosestDistances.name(): ClosestDistances, 34 | ClosestFovAngles.name(): ClosestFovAngles, 35 | ClosestFovDistances.name(): ClosestFovDistances, 36 | PairwiseSocialDistances.name(): PairwiseSocialDistances, 37 | PairwiseSocialFovDistances.name(): PairwiseSocialFovDistances, 38 | } 39 | 40 | def __init__(self, poses: PoseEstimation, pixel_scale: float): 41 | super().__init__(poses, pixel_scale) 42 | self._closest_identities_cache = None 43 | 44 | def _init_feature_mods(self, identity: int): 45 | """initialize all of the feature modules specified in the current config 46 | 47 | Args: 48 | identity: subject identity to use when computing social 49 | features 50 | 51 | Returns: 52 | dictionary of initialized feature modules for this group 53 | """ 54 | # cache the most recent ClosestIdentityInfo, it's needed by 55 | # the IdentityFeatures class when saving the social features to the 56 | # h5 file 57 | self._closest_identities_cache = ClosestIdentityInfo( 58 | self._poses, identity, self._pixel_scale 59 | ) 60 | 61 | # initialize all the feature modules specified in the current config 62 | return { 63 | feature: self._features[feature]( 64 | self._poses, self._pixel_scale, self._closest_identities_cache 65 | ) 66 | for feature in self._enabled_features 67 | } 68 | 69 | @property 70 | def closest_identities(self): 71 | """return cached closet identities""" 72 | return self._closest_identities_cache 73 | -------------------------------------------------------------------------------- /src/jabs/feature_extraction/window_operations/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The `window_operations` package provides functions for computing statistical features over sliding windows. 3 | 4 | It includes utilities for calculating windowed statistics such as minimum, maximum, skewness, and other 5 | descriptive metrics, which are useful for feature extraction in time series and signal processing tasks. 
6 | """ 7 | -------------------------------------------------------------------------------- /src/jabs/feature_extraction/window_operations/signal_stats.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | from scipy.stats import kurtosis, skew 5 | 6 | 7 | def psd_sum(freqs: np.ndarray, psd: np.ndarray) -> np.ndarray: 8 | """Calculates the sum power spectral density 9 | 10 | Args: 11 | freqs: frequencies in the psd, ignored 12 | psd: power spectral density matrix 13 | 14 | Returns: 15 | sum of power 16 | """ 17 | return np.sum(psd, axis=0) 18 | 19 | 20 | def psd_max(freqs: np.ndarray, psd: np.ndarray) -> np.ndarray: 21 | """Calculates the max power 22 | 23 | Args: 24 | freqs: frequencies in the psd, ignored 25 | psd: power spectral density matrix 26 | 27 | Returns: 28 | max of power 29 | """ 30 | return np.nanmax(psd, axis=0) 31 | 32 | 33 | def psd_min(freqs: np.ndarray, psd: np.ndarray) -> np.ndarray: 34 | """Calculates the min power 35 | 36 | Args: 37 | freqs: frequencies in the psd, ignored 38 | psd: power spectral density matrix 39 | 40 | Returns: 41 | min of power 42 | """ 43 | return np.min(psd, axis=0) 44 | 45 | 46 | def psd_mean(freqs: np.ndarray, psd: np.ndarray) -> np.ndarray: 47 | """Calculates the mean power spectral density 48 | 49 | Args: 50 | freqs: frequencies in the psd, ignored 51 | psd: power spectral density matrix 52 | 53 | Returns: 54 | mean of power 55 | """ 56 | return np.mean(psd, axis=0) 57 | 58 | 59 | def psd_mean_band( 60 | freqs: np.ndarray, 61 | psd: np.ndarray, 62 | band_low: int = 0, 63 | band_high: float = np.finfo(np.float64).max, 64 | ) -> np.ndarray: 65 | """Calculates the mean power spectral density in a band 66 | 67 | Args: 68 | freqs: frequencies in the psd, ignored 69 | psd: power spectral density matrix 70 | band_low: lower bound of the frequency band 71 | band_high: upper bound of the frequency band 72 | 73 | Returns: 74 | mean of power 75 | """ 76 | idx = np.logical_and(freqs >= band_low, freqs < band_high) 77 | 78 | if not np.any(idx): 79 | return np.full([psd.shape[1]], np.nan) 80 | return np.mean(np.asarray(psd)[idx], axis=0) 81 | 82 | 83 | def psd_median(freqs: np.ndarray, psd: np.ndarray) -> np.ndarray: 84 | """Calculates the median power spectral density 85 | 86 | Args: 87 | freqs: frequencies in the psd, ignored 88 | psd: power spectral density matrix 89 | 90 | Returns: 91 | median of power 92 | """ 93 | return np.median(psd, axis=0) 94 | 95 | 96 | def psd_std_dev(freqs: np.ndarray, psd: np.ndarray) -> np.ndarray: 97 | """Calculates the standard deviation power spectral density 98 | 99 | Args: 100 | freqs: frequencies in the psd, ignored 101 | psd: power spectral density matrix 102 | 103 | Returns: 104 | standard deviation of power 105 | """ 106 | return np.std(psd, axis=0) 107 | 108 | 109 | def psd_kurtosis(freqs: np.ndarray, psd: np.ndarray) -> np.ndarray: 110 | """Calculates the kurtosis power spectral density 111 | 112 | Args: 113 | freqs: frequencies in the psd, ignored 114 | psd: power spectral density matrix 115 | 116 | Returns: 117 | kurtosis of power 118 | """ 119 | with warnings.catch_warnings(): 120 | warnings.simplefilter("ignore", category=RuntimeWarning) 121 | return_values = kurtosis(psd, axis=0, nan_policy="omit") 122 | # If infinity shows up, convert to nan 123 | return_values = np.nan_to_num( 124 | return_values, nan=np.nan, posinf=np.nan, neginf=np.nan 125 | ) 126 | return return_values 127 | 128 | 129 | def psd_skew(freqs: np.ndarray, psd: 
130 |     """Calculates the skew of the power spectral density
131 | 
132 |     Args:
133 |         freqs: frequencies in the psd, ignored
134 |         psd: power spectral density matrix
135 | 
136 |     Returns:
137 |         skew of power
138 |     """
139 |     with warnings.catch_warnings():
140 |         warnings.simplefilter("ignore", category=RuntimeWarning)
141 |         return_values = skew(psd, axis=0, nan_policy="omit")
142 |     # If infinity shows up, convert to nan
143 |     return_values = np.nan_to_num(
144 |         return_values, nan=np.nan, posinf=np.nan, neginf=np.nan
145 |     )
146 |     return return_values
147 | 
148 | 
149 | def psd_peak_freq(freqs: np.ndarray, psd: np.ndarray) -> np.ndarray:
150 |     """Calculates the frequency with the most power
151 | 
152 |     Args:
153 |         freqs: frequencies in the psd
154 |         psd: power spectral density matrix
155 | 
156 |     Returns:
157 |         frequency with highest power
158 |     """
159 |     return freqs[np.argmax(psd, axis=0)]
160 | 
--------------------------------------------------------------------------------
/src/jabs/pose_estimation/__init__.py:
--------------------------------------------------------------------------------
  1 | """JABS pose file handler module"""
  2 | 
  3 | import re
  4 | from pathlib import Path
  5 | 
  6 | import h5py
  7 | 
  8 | from .pose_est import MINIMUM_CONFIDENCE, PoseEstimation, PoseHashException
  9 | from .pose_est_v2 import PoseEstimationV2
 10 | from .pose_est_v3 import PoseEstimationV3
 11 | from .pose_est_v4 import PoseEstimationV4
 12 | from .pose_est_v5 import PoseEstimationV5
 13 | from .pose_est_v6 import PoseEstimationV6
 14 | 
 15 | 
 16 | def open_pose_file(path: Path, cache_dir: Path | None = None):
 17 |     """open a pose file using the correct PoseEstimation subclass based on the version implied by the filename"""
 18 |     if path.name.endswith("v2.h5"):
 19 |         return PoseEstimationV2(path, cache_dir)
 20 |     elif path.name.endswith("v3.h5"):
 21 |         return PoseEstimationV3(path, cache_dir)
 22 |     elif path.name.endswith("v4.h5"):
 23 |         return PoseEstimationV4(path, cache_dir)
 24 |     elif path.name.endswith("v5.h5"):
 25 |         return PoseEstimationV5(path, cache_dir)
 26 |     elif path.name.endswith("v6.h5"):
 27 |         return PoseEstimationV6(path, cache_dir)
 28 |     else:
 29 |         raise ValueError("not a valid pose estimate filename")
 30 | 
 31 | 
 32 | def get_pose_path(video_path: Path):
 33 |     """take a path to a video file and return the path to the corresponding pose_est h5 file
 34 | 
 35 |     Args:
 36 |         video_path: Path to video file in project
 37 | 
 38 |     Returns:
 39 |         Path object representing location of corresponding pose_est h5 file
 40 | 
 41 |     Raises:
 42 |         ValueError: if video_path does not have corresponding pose_est file
 43 |     """
 44 |     file_base = video_path.with_suffix("")
 45 | 
 46 |     # default to the highest version pose file for a video
 47 |     if video_path.with_name(file_base.name + "_pose_est_v6.h5").exists():
 48 |         return video_path.with_name(file_base.name + "_pose_est_v6.h5")
 49 |     elif video_path.with_name(file_base.name + "_pose_est_v5.h5").exists():
 50 |         return video_path.with_name(file_base.name + "_pose_est_v5.h5")
 51 |     elif video_path.with_name(file_base.name + "_pose_est_v4.h5").exists():
 52 |         return video_path.with_name(file_base.name + "_pose_est_v4.h5")
 53 |     elif video_path.with_name(file_base.name + "_pose_est_v3.h5").exists():
 54 |         return video_path.with_name(file_base.name + "_pose_est_v3.h5")
 55 |     elif video_path.with_name(file_base.name + "_pose_est_v2.h5").exists():
 56 |         return video_path.with_name(file_base.name + "_pose_est_v2.h5")
 57 |     else:
 58 |         raise ValueError("Video does not have pose file")
 59 | 
 60 | 
 61 | def get_pose_file_major_version(path: Path):
 62 |     """get the major version of a pose file from the _filename_
 63 | 
 64 |     Note: does not inspect contents of file, assumes file name matches
 65 |     video_name_v[version number].h5
 66 | 
 67 |     Args:
 68 |         path: path of pose file
 69 | 
 70 |     Returns:
 71 |         integer major version number
 72 |     """
 73 |     v = re.search(r"_v([0-9]+)\.h5", str(path)).group(1)
 74 |     return int(v)
 75 | 
 76 | 
 77 | def get_frames_from_file(path: Path):
 78 |     """peek into a pose_est file to count number of frames"""
 79 |     with h5py.File(path, "r") as pose_h5:
 80 |         vid_grp = pose_h5["poseest"]
 81 |         return vid_grp["points"].shape[0]
 82 | 
 83 | 
 84 | def get_static_objects_in_file(path: Path):
 85 |     """peek into a pose file to get a list of the static objects it contains
 86 | 
 87 |     Args:
 88 |         path: path of pose file
 89 | 
 90 |     Returns:
 91 |         list of static object names contained in pose file
 92 |     """
 93 |     if get_pose_file_major_version(path) >= 5:
 94 |         with h5py.File(path, "r") as pose_h5:
 95 |             if "static_objects" in pose_h5:
 96 |                 return list(pose_h5["static_objects"].keys())
 97 |     return []
 98 | 
 99 | 
100 | def get_points_per_lixit(path: Path) -> int:
101 |     """inspect a pose file to get the number of keypoints per lixit
102 | 
103 |     returns zero if the pose file does not have any lixit keypoints.
104 |     """
105 |     points_per_lixit = 0
106 |     if get_pose_file_major_version(path) >= 5:
107 |         with h5py.File(path, "r") as pose_h5:
108 |             if "static_objects" in pose_h5 and "lixit" in pose_h5["static_objects"]:
109 |                 if pose_h5["static_objects"]["lixit"].ndim == 3:
110 |                     points_per_lixit = 3
111 |                 else:
112 |                     points_per_lixit = 1
113 |     return points_per_lixit
114 | 
--------------------------------------------------------------------------------
/src/jabs/pose_estimation/pose_est_v2.py:
--------------------------------------------------------------------------------
  1 | from pathlib import Path
  2 | 
  3 | import h5py
  4 | import numpy as np
  5 | 
  6 | from .pose_est import MINIMUM_CONFIDENCE, PoseEstimation
  7 | 
  8 | 
  9 | class PoseEstimationV2(PoseEstimation):
 10 |     """read in pose_est_v2.h5 file"""
 11 | 
 12 |     def __init__(self, file_path: Path, cache_dir: Path | None = None, fps: int = 30):
 13 |         """initialize new object from h5 file
 14 | 
 15 |         Args:
 16 |             file_path: path to pose_est_v2.h5 file
 17 |             cache_dir: optional cache directory, used to cache convex
 18 |                 hulls for faster loading
 19 |             fps: frames per second, used for scaling time series
 20 |                 features from "per frame" to "per second"
 21 |         """
 22 |         super().__init__(file_path, cache_dir, fps)
 23 | 
 24 |         # we will make this look like the PoseEstimationV3 but with a single
 25 |         # identity so the main program won't care which type it is
 26 |         self._identities = [0]
 27 |         self._max_instances = 1
 28 | 
 29 |         self._path = file_path
 30 | 
 31 |         # open the hdf5 pose file
 32 |         with h5py.File(self._path, "r") as pose_h5:
 33 |             # extract data from the HDF5 file
 34 |             pose_grp = pose_h5["poseest"]
 35 | 
 36 |             # load contents
 37 |             # keypoints are stored as (y,x)
 38 |             self._points = np.flip(pose_grp["points"][:].astype(np.float64), axis=-1)
 39 |             self._point_mask = np.zeros(self._points.shape[:-1], dtype=np.uint16)
 40 |             self._point_mask[:] = pose_grp["confidence"][:] > MINIMUM_CONFIDENCE
 41 | 
 42 |             # get pixel size
 43 |             self._cm_per_pixel = pose_grp.attrs.get("cm_per_pixel", None)
 44 | 
 45 |         self._num_frames = self._points.shape[0]
 46 | 
 47 |         # build an array that indicates if the identity exists for each frame
 48 |         # require at least 3 body points, not just tail
 49 |         init_func = np.vectorize(
50 | lambda x: 0 if np.sum(self._point_mask[x][:-2]) < 3 else 1, 51 | otypes=[np.uint8], 52 | ) 53 | self._identity_mask = np.fromfunction( 54 | init_func, (self._num_frames,), dtype=np.int_ 55 | ) 56 | 57 | @property 58 | def identity_to_track(self): 59 | """get the identity to track mapping 60 | 61 | for pose_est_v2, this is always None because jabs doesn't do any track to identity mapping for the single 62 | mouse pose files 63 | """ 64 | return None 65 | 66 | @property 67 | def format_major_version(self): 68 | """get the major version of the pose file format""" 69 | return 2 70 | 71 | def get_points(self, frame_index: int, identity: int, scale: float | None = None): 72 | """return points and point masks for an individual frame 73 | 74 | Args: 75 | frame_index: frame index of points and masks to be returned 76 | identity: included for compatibility with pose_est_v3. 77 | Should always be zero. 78 | scale: optional scale factor, set to cm_per_pixel to convert 79 | poses from pixel coordinates to cm coordinates 80 | 81 | Returns: 82 | numpy array of points (12,2), numpy array of point masks (12,) 83 | """ 84 | if identity not in self.identities: 85 | raise ValueError("Invalid identity") 86 | 87 | if not self._identity_mask[frame_index]: 88 | return None, None 89 | 90 | if scale is not None: 91 | return self._points[frame_index] * scale, self._point_mask[frame_index] 92 | else: 93 | return self._points[frame_index], self._point_mask[frame_index] 94 | 95 | def get_identity_poses(self, identity: int, scale: float | None = None): 96 | """return all points and point masks 97 | 98 | Args: 99 | identity: included for compatibility with pose_est_v3. 100 | Should always be zero. 101 | scale: optional scale factor, set to cm_per_pixel to convert 102 | poses from pixel coordinates to cm coordinates 103 | 104 | Returns: 105 | numpy array of points (#frames, 12, 2), numpy array of point masks (#frames, 12) 106 | """ 107 | if identity not in self.identities: 108 | raise ValueError("Invalid identity") 109 | 110 | if scale is not None: 111 | return self._points * scale, self._point_mask 112 | else: 113 | return self._points, self._point_mask 114 | 115 | def identity_mask(self, identity): 116 | """get the identity mask (indicates if specified identity is present in each frame) 117 | 118 | Args: 119 | identity: included for compatibility with pose_est_v3 interface. Should always be zero. 120 | 121 | Returns: 122 | numpy array of size (#frames,) 123 | """ 124 | if identity not in self.identities: 125 | raise ValueError("Invalid identity") 126 | return self._identity_mask 127 | 128 | def get_identity_point_mask(self, identity): 129 | """get the point mask array for a given identity 130 | 131 | Args: 132 | identity: identity to return point mask for 133 | 134 | Returns: 135 | array of point masks (#frames, 12) 136 | """ 137 | return self._point_mask 138 | -------------------------------------------------------------------------------- /src/jabs/pose_estimation/pose_est_v5.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import h5py 4 | import numpy as np 5 | 6 | from .pose_est_v4 import PoseEstimationV4 7 | 8 | OBJECTS_STORED_YX = [ 9 | "lixit", 10 | "food_hopper", 11 | ] 12 | 13 | 14 | class PoseEstimationV5(PoseEstimationV4): 15 | """Pose estimation handler for version 5 pose files with static object support. 
 16 | 
 17 |     Extends PoseEstimationV4 to add reading and management of static object data
 18 |     (such as lixit and food hopper positions) from pose v5 HDF5 files. Handles
 19 |     additional datasets introduced in v5, including logic for different lixit
 20 |     keypoint configurations.
 21 | 
 22 |     Args:
 23 |         file_path (Path): Path to the pose HDF5 file.
 24 |         cache_dir (Path | None): Optional cache directory for intermediate data.
 25 |         fps (int): Frames per second for the video.
 26 | 
 27 |     """
 28 | 
 29 |     def __init__(self, file_path: Path, cache_dir: Path | None = None, fps: int = 30):
 30 |         super().__init__(file_path, cache_dir, fps)
 31 | 
 32 |         # V5 files are the same as V4, except they carry some additional datasets
 33 |         # beyond the poseest data. The pose data is all loaded by
 34 |         # calling super().__init__(), so now we just need to load the additional
 35 |         # data
 36 | 
 37 |         self._static_objects = {}
 38 |         self._lixit_keypoints = 0
 39 | 
 40 |         # open the hdf5 pose file
 41 |         with h5py.File(self._path, "r") as pose_h5:
 42 |             # extract data from the HDF5 file
 43 |             for g in pose_h5:
 44 |                 # skip over the poseest group, since that's already been
 45 |                 # processed
 46 |                 if g == "poseest":
 47 |                     continue
 48 | 
 49 |                 # v5 adds a 'static_objects' dataset, but is otherwise the same as v4
 50 |                 if g == "static_objects":
 51 |                     for d in pose_h5["static_objects"]:
 52 |                         static_object_data = pose_h5["static_objects"][d][:]
 53 |                         if d in OBJECTS_STORED_YX:
 54 |                             static_object_data = np.flip(static_object_data, axis=-1)
 55 |                         self._static_objects[d] = static_object_data
 56 | 
 57 |         if "lixit" in self._static_objects:
 58 |             # drop "lixit" from the static objects if it is an empty array
 59 |             if self._static_objects["lixit"].shape[0] == 0:
 60 |                 del self._static_objects["lixit"]
 61 |             else:
 62 |                 # if the lixit data is not empty, we need to get the number of
 63 |                 # keypoints in the lixit data
 64 |                 if self._static_objects["lixit"].ndim == 3:
 65 |                     # if the lixit data is 3D, it means we have 3 points per
 66 |                     # lixit (tip, left side, right side -- in that order) and the shape is #lixit x 3 x 2
 67 |                     self._lixit_keypoints = 3
 68 |                 else:
 69 |                     # if the lixit data is 2D, it means we have 1 point per
 70 |                     # lixit (tip) and the shape is #lixit x 2
 71 |                     self._lixit_keypoints = 1
 72 | 
 73 |     @property
 74 |     def format_major_version(self) -> int:
 75 |         """get the major version of the pose file format"""
 76 |         return 5
 77 | 
 78 |     @property
 79 |     def num_lixit_keypoints(self) -> int:
 80 |         """get the number of lixit keypoints"""
 81 |         return self._lixit_keypoints
 82 | 
--------------------------------------------------------------------------------
/src/jabs/project/__init__.py:
--------------------------------------------------------------------------------
  1 | """jabs project module"""
  2 | 
  3 | from .export_training import export_training_data
  4 | from .project import Project
  5 | from .read_training import load_training_data
  6 | from .track_labels import TrackLabels
  7 | from .video_labels import VideoLabels
  8 | 
  9 | __all__ = [
 10 |     "Project",
 11 |     "TrackLabels",
 12 |     "VideoLabels",
 13 |     "export_training_data",
 14 |     "load_training_data",
 15 | ]
 16 | 
--------------------------------------------------------------------------------
/src/jabs/project/export_training.py:
--------------------------------------------------------------------------------
  1 | from datetime import datetime
  2 | from pathlib import Path
  3 | from typing import TYPE_CHECKING
  4 | 
  5 | import h5py
  6 | import numpy as np
  7 | 
  8 | import jabs.feature_extraction
  9 | import jabs.version
 10 | from
jabs.project.project_utils import to_safe_name 11 | 12 | # these are used for type hints, but cause circular imports 13 | # TYPE_CHECKING is always false at runtime, so this gets around that 14 | # also requires enclosing Project and Classifier type hints in quotes 15 | if TYPE_CHECKING: 16 | from jabs.project import Project 17 | from jabs.types import ClassifierType 18 | 19 | 20 | def export_training_data( 21 | project: "Project", 22 | behavior: str, 23 | pose_version: int, 24 | classifier_type: "ClassifierType", 25 | training_seed: int, 26 | out_file: Path | None = None, 27 | ): 28 | """ 29 | Export labeled training data from a JABS project for classifier retraining. 30 | 31 | This function extracts features and labels for a specified behavior and writes them, 32 | along with relevant project and classifier metadata, to an HDF5 file. The exported 33 | file can be used for retraining classifiers outside the current environment. 34 | 35 | Args: 36 | project (Project): The JABS project to export data from. 37 | behavior (str): Name of the behavior to export. 38 | pose_version (int): Minimum required pose version for the classifier. 39 | classifier_type (ClassifierType): The classifier type for which data is exported. 40 | training_seed (int): Random seed to ensure reproducible training splits. 41 | out_file (Path, optional): Output file path. If None, a file is created in the 42 | project directory with a timestamped name. 43 | 44 | Returns: 45 | Path: The path to the exported HDF5 file. 46 | 47 | Raises: 48 | OSError: If the output file cannot be created or written. 49 | """ 50 | ts = datetime.now().strftime("%Y%m%d_%H%M%S") 51 | features, group_mapping = project.get_labeled_features(behavior) 52 | 53 | if out_file is None: 54 | out_file = project.dir / f"{to_safe_name(behavior)}_training_{ts}.h5" 55 | 56 | string_type = h5py.special_dtype(vlen=str) 57 | 58 | with h5py.File(out_file, "w") as out_h5: 59 | out_h5.attrs["file_version"] = jabs.feature_extraction.FEATURE_VERSION 60 | out_h5.attrs["app_version"] = jabs.version.version_str() 61 | out_h5.attrs["min_pose_version"] = pose_version 62 | out_h5.attrs["behavior"] = behavior 63 | write_project_settings( 64 | out_h5, project.settings_manager.get_behavior(behavior), "settings" 65 | ) 66 | out_h5.attrs["classifier_type"] = classifier_type.value 67 | out_h5.attrs["training_seed"] = training_seed 68 | feature_group = out_h5.create_group("features") 69 | for feature, data in features["per_frame"].items(): 70 | feature_group.create_dataset(f"per_frame/{feature}", data=data) 71 | for feature, data in features["window"].items(): 72 | feature_group.create_dataset(f"window/{feature}", data=data) 73 | 74 | out_h5.create_dataset("group", data=features["groups"]) 75 | out_h5.create_dataset("label", data=features["labels"]) 76 | 77 | # store the video/identity to group mapping in the h5 file 78 | for group in group_mapping: 79 | dset = out_h5.create_dataset( 80 | f"group_mapping/{group}/identity", (1,), dtype=np.int64 81 | ) 82 | dset[:] = group_mapping[group]["identity"] 83 | dset = out_h5.create_dataset( 84 | f"group_mapping/{group}/video_name", (1,), dtype=string_type 85 | ) 86 | dset[:] = group_mapping[group]["video"] 87 | 88 | # return output path, so if it was generated automatically the caller 89 | # will know 90 | return out_file 91 | 92 | 93 | def write_project_settings( 94 | h5_file: h5py.File | h5py.Group, settings: dict, node: str = "settings" 95 | ): 96 | """write project settings to a training h5 file recursively 97 | 98 | Args: 99 | 
h5_file: open h5 file to write to 100 | settings: dict of project settings 101 | node: name of the node to write to 102 | """ 103 | current_group = h5_file.require_group(node) 104 | for key, val in settings.items(): 105 | if type(val) is dict: 106 | write_project_settings(current_group, val, key) 107 | else: 108 | current_group.create_dataset(key, data=val) 109 | -------------------------------------------------------------------------------- /src/jabs/project/project_paths.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | 4 | class ProjectPaths: 5 | """Class to manage project paths.""" 6 | 7 | __JABS_DIR = "jabs" 8 | __PROJECT_FILE = "project.json" 9 | 10 | def __init__(self, base_path: Path, use_cache: bool = True): 11 | self._base_path = base_path 12 | 13 | self._jabs_dir = base_path / self.__JABS_DIR 14 | self._annotations_dir = self._jabs_dir / "annotations" 15 | self._feature_dir = self._jabs_dir / "features" 16 | self._prediction_dir = self._jabs_dir / "predictions" 17 | self._classifier_dir = self._jabs_dir / "classifiers" 18 | self._archive_dir = self._jabs_dir / "archive" 19 | self._cache_dir = self._jabs_dir / "cache" if use_cache else None 20 | 21 | self._project_file = self._jabs_dir / self.__PROJECT_FILE 22 | 23 | @property 24 | def project_dir(self) -> Path: 25 | """Get the base path of the project.""" 26 | return self._base_path 27 | 28 | @property 29 | def jabs_dir(self) -> Path: 30 | """Get the path to the JABS directory.""" 31 | return self._jabs_dir 32 | 33 | @property 34 | def annotations_dir(self) -> Path: 35 | """Get the path to the annotations directory.""" 36 | return self._annotations_dir 37 | 38 | @property 39 | def feature_dir(self) -> Path: 40 | """Get the path to the features directory.""" 41 | return self._feature_dir 42 | 43 | @property 44 | def prediction_dir(self) -> Path: 45 | """Get the path to the predictions directory.""" 46 | return self._prediction_dir 47 | 48 | @property 49 | def project_file(self) -> Path: 50 | """Get the path to the project file.""" 51 | return self._project_file 52 | 53 | @property 54 | def classifier_dir(self) -> Path: 55 | """Get the path to the classifiers directory.""" 56 | return self._classifier_dir 57 | 58 | @property 59 | def archive_dir(self) -> Path: 60 | """Get the path to the archive directory.""" 61 | return self._archive_dir 62 | 63 | @property 64 | def cache_dir(self) -> Path | None: 65 | """Get the path to the cache directory.""" 66 | return self._cache_dir 67 | 68 | def create_directories(self): 69 | """Create all necessary directories for the project.""" 70 | self._annotations_dir.mkdir(parents=True, exist_ok=True) 71 | self._feature_dir.mkdir(parents=True, exist_ok=True) 72 | self._prediction_dir.mkdir(parents=True, exist_ok=True) 73 | self._classifier_dir.mkdir(parents=True, exist_ok=True) 74 | self._archive_dir.mkdir(parents=True, exist_ok=True) 75 | 76 | if self._cache_dir: 77 | self._cache_dir.mkdir(parents=True, exist_ok=True) 78 | -------------------------------------------------------------------------------- /src/jabs/project/project_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def to_safe_name(behavior: str) -> str: 5 | """Create a version of the given behavior name that should be safe to use in filenames. 
  6 | 
  7 |     Args:
  8 |         behavior: string behavior name
  9 | 
 10 |     Returns:
 11 |         sanitized behavior name
 12 | 
 13 |     Raises:
 14 |         ValueError: if the behavior name is empty after sanitization
 15 |     """
 16 |     safe_behavior = re.sub(r"[^\w.-]+", "_", behavior, flags=re.UNICODE)
 17 |     # get rid of consecutive underscores
 18 |     safe_behavior = re.sub("_{2,}", "_", safe_behavior)
 19 | 
 20 |     # Remove leading and trailing underscores
 21 |     safe_behavior = safe_behavior.lstrip("_").rstrip("_")
 22 | 
 23 |     if safe_behavior == "":
 24 |         raise ValueError("Behavior name is empty after sanitization.")
 25 |     return safe_behavior
 26 | 
--------------------------------------------------------------------------------
/src/jabs/project/read_training.py:
--------------------------------------------------------------------------------
  1 | from pathlib import Path
  2 | 
  3 | import h5py
  4 | import pandas as pd
  5 | 
  6 | from jabs.types import ClassifierType, ProjectDistanceUnit
  7 | 
  8 | 
  9 | def read_project_settings(h5_file: h5py.Group) -> dict:
 10 |     """read dict of project settings
 11 | 
 12 |     Args:
 13 |         h5_file: open h5 file to read settings from
 14 | 
 15 |     Returns:
 16 |         dictionary of all project settings
 17 |     """
 18 |     all_settings = {}
 19 |     root_len = len(h5_file.name) + 1
 20 | 
 21 |     def _walk_project_settings(name, node) -> None:
 22 |         """project settings walker, meant to be used with h5py's visititems
 23 | 
 24 |         Args:
 25 |             name: path of the visited node, relative to the walked root
 26 |             node: the h5py object currently being visited
 27 | 
 28 |         Returns:
 29 |             None; results are accumulated in the enclosing all_settings dict
 30 | 
 31 |         this walk can't use return/yield, so we just mutate the dict on each visit
 32 |         settings can only be a max of 1 deep
 33 |         """
 34 |         fullname = node.name[root_len:]
 35 |         if isinstance(node, h5py.Dataset):
 36 |             if "/" in fullname:
 37 |                 level_name, key = fullname.split("/")
 38 |                 level_settings = all_settings.get(level_name, {})
 39 |                 level_settings.update({key: node[...].item()})
 40 |                 all_settings.update({level_name: level_settings})
 41 |             else:
 42 |                 all_settings.update({fullname: node[...].item()})
 43 | 
 44 |     h5_file.visititems(_walk_project_settings)
 45 |     return all_settings
 46 | 
 47 | 
 48 | 
 49 | def load_training_data(training_file: Path):
 50 |     """load training data from file
 51 | 
 52 |     Args:
 53 |         training_file: path to training h5 file
 54 | 
 55 |     Returns:
 56 |         a tuple (features, group_mapping). features: dict containing training data
 57 |         with the following format: {
 58 |             'per_frame': {},
 59 |             'window': {},
 60 |             'labels': [int],
 61 |             'groups': [int],
 62 |             'behavior': str,
 63 |             'settings': {},
 64 |             'classifier_type': ClassifierType,
 65 |         }
 66 | 
 67 |         group_mapping: dict containing group to identity/video mapping:
 68 |         {
 69 |             group_id: {
 70 |                 'identity': int,
 71 |                 'video': str
 72 |             },
 73 |         }
 74 |     """
 75 |     features = {"per_frame": {}, "window": {}}
 76 |     group_mapping = {}
 77 | 
 78 |     with h5py.File(training_file, "r") as in_h5:
 79 |         features["min_pose_version"] = in_h5.attrs["min_pose_version"]
 80 |         features["behavior"] = in_h5.attrs["behavior"]
 81 |         features["settings"] = read_project_settings(in_h5["settings"])
 82 |         features["training_seed"] = in_h5.attrs["training_seed"]
 83 |         features["classifier_type"] = ClassifierType(in_h5.attrs["classifier_type"])
 84 |         # convert the string distance_unit attr to corresponding
 85 |         # ProjectDistanceUnit enum
 86 |         unit = in_h5.attrs.get("distance_unit")
 87 |         if unit is None:
 88 |             # if the training file doesn't include distance_unit it is old and
 89 |             # definitely used pixel based distances
 90 |             features["distance_unit"] = ProjectDistanceUnit.PIXEL
 91 |         else:
92 | features["distance_unit"] = ProjectDistanceUnit[unit] 93 | 94 | features["labels"] = in_h5["label"][:] 95 | features["groups"] = in_h5["group"][:] 96 | 97 | # per frame features 98 | for name, val in in_h5["features/per_frame"].items(): 99 | features["per_frame"][name] = val[:] 100 | features["per_frame"] = pd.DataFrame(features["per_frame"]) 101 | # window features 102 | for name, val in in_h5["features/window"].items(): 103 | features["window"][name] = val[:] 104 | features["window"] = pd.DataFrame(features["window"]) 105 | 106 | # extract the group mapping from h5 file 107 | for name, val in in_h5["group_mapping"].items(): 108 | group_mapping[int(name)] = { 109 | "identity": val["identity"][0], 110 | "video": val["video_name"][0], 111 | } 112 | 113 | # load required extended features 114 | if "extended_features" in in_h5: 115 | features["extended_features"] = {} 116 | for group in in_h5["extended_features"]: 117 | features["extended_features"][group] = [] 118 | for f in in_h5[f"extended_features/{group}"]: 119 | features["extended_features"][group].append(f.decode("utf-8")) 120 | else: 121 | features["extended_features"] = None 122 | 123 | return features, group_mapping 124 | -------------------------------------------------------------------------------- /src/jabs/project/settings_manager.py: -------------------------------------------------------------------------------- 1 | import json 2 | import typing 3 | 4 | import jabs.feature_extraction as feature_extraction 5 | from jabs.version import version_str 6 | 7 | if typing.TYPE_CHECKING: 8 | from .project_paths import ProjectPaths 9 | 10 | 11 | class SettingsManager: 12 | """Class to manage project properties/settings.""" 13 | 14 | def __init__(self, project_paths: "ProjectPaths"): 15 | """Initialize the ProjectProperties. 16 | 17 | Args: 18 | project_paths: ProjectPaths object to manage file paths. 19 | """ 20 | self._paths = project_paths 21 | self._project_info = self._load_project_file() 22 | 23 | def _load_project_file(self) -> dict: 24 | """Load project properties from the project file. 25 | 26 | Returns: 27 | Dictionary of project properties. 28 | """ 29 | try: 30 | with self._paths.project_file.open(mode="r", newline="\n") as f: 31 | settings = json.load(f) 32 | except FileNotFoundError: 33 | settings = {} 34 | 35 | # Ensure default keys exist 36 | settings.setdefault("behavior", {}) 37 | settings.setdefault("window_sizes", [feature_extraction.DEFAULT_WINDOW_SIZE]) 38 | 39 | return settings 40 | 41 | def save_project_file(self, data: dict | None = None): 42 | """Save project properties & settings to the project file. 43 | 44 | Args: 45 | data: Dictionary with state information to save. 46 | """ 47 | # Merge data with current metadata 48 | if data is not None: 49 | self._project_info.update(data) 50 | 51 | self._project_info["version"] = version_str() 52 | 53 | # Save combined info to file 54 | with self._paths.project_file.open(mode="w", newline="\n") as f: 55 | json.dump(self._project_info, f, indent=2, sort_keys=True) 56 | 57 | @property 58 | def project_settings(self) -> dict: 59 | """Get a copy of the current project properties and settings. 60 | 61 | Returns: 62 | dict 63 | """ 64 | return dict(self._project_info) 65 | 66 | def save_behavior(self, behavior: str, data: dict): 67 | """Save a behavior to project file. 68 | 69 | Args: 70 | behavior: Behavior name. 71 | data: Dictionary of behavior settings. 
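        Example (illustrative, with a hypothetical behavior name and setting; writes the project file):
            >>> manager = SettingsManager(project_paths)
            >>> manager.save_behavior("Grooming", {"window_size": 5})
            >>> manager.get_behavior("Grooming")["window_size"]
            5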
72 | """ 73 | defaults = self._project_info.get("defaults", {}) 74 | 75 | all_behavior_data = self._project_info.get("behavior", {}) 76 | merged_data = all_behavior_data.get(behavior, defaults) 77 | merged_data.update(data) 78 | 79 | all_behavior_data[behavior] = merged_data 80 | self.save_project_file({"behavior": all_behavior_data}) 81 | 82 | def get_behavior(self, behavior: str) -> dict: 83 | """Get metadata specific to a requested behavior. 84 | 85 | Args: 86 | behavior: Behavior key to read. 87 | 88 | Returns: 89 | Dictionary of behavior metadata. 90 | """ 91 | return self._project_info.get("behavior", {}).get(behavior, {}) 92 | 93 | def remove_behavior(self, behavior: str) -> None: 94 | """remove behavior from project settings""" 95 | try: 96 | del self._project_info["behavior"][behavior] 97 | self.save_project_file() 98 | except KeyError: 99 | pass 100 | 101 | def update_version(self): 102 | """Update the version number in the metadata if it differs from the current version.""" 103 | current_version = self._project_info.get("version") 104 | if current_version != version_str(): 105 | self.save_project_file({"version": version_str()}) 106 | -------------------------------------------------------------------------------- /src/jabs/resources/__init__.py: -------------------------------------------------------------------------------- 1 | """Resource file paths. 2 | 3 | This package provides package-aware access to application resources 4 | such as documentation files and icons using `importlib.resources`. 5 | 6 | Attributes: 7 | DOCS_DIR (pathlib.Path): Path object to the documentation directory. 8 | ICON_PATH (pathlib.Path): Path object to the application icon. 9 | """ 10 | 11 | import importlib.resources 12 | 13 | DOCS_DIR = importlib.resources.files("jabs.resources") / "docs" 14 | ICON_PATH = importlib.resources.files("jabs.resources") / "icon.png" 15 | 16 | 17 | __all__ = [ 18 | "DOCS_DIR", 19 | "ICON_PATH", 20 | ] 21 | -------------------------------------------------------------------------------- /src/jabs/resources/docs/user_guide/imgs/classifier_controls.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumarLabJax/JABS-behavior-classifier/4749d378a18e51eed91eea04d0f002472b8abcc2/src/jabs/resources/docs/user_guide/imgs/classifier_controls.png -------------------------------------------------------------------------------- /src/jabs/resources/docs/user_guide/imgs/identity_gaps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumarLabJax/JABS-behavior-classifier/4749d378a18e51eed91eea04d0f002472b8abcc2/src/jabs/resources/docs/user_guide/imgs/identity_gaps.png -------------------------------------------------------------------------------- /src/jabs/resources/docs/user_guide/imgs/label_viz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumarLabJax/JABS-behavior-classifier/4749d378a18e51eed91eea04d0f002472b8abcc2/src/jabs/resources/docs/user_guide/imgs/label_viz.png -------------------------------------------------------------------------------- /src/jabs/resources/docs/user_guide/imgs/main_window.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumarLabJax/JABS-behavior-classifier/4749d378a18e51eed91eea04d0f002472b8abcc2/src/jabs/resources/docs/user_guide/imgs/main_window.png 
-------------------------------------------------------------------------------- /src/jabs/resources/docs/user_guide/imgs/pose_overlay.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumarLabJax/JABS-behavior-classifier/4749d378a18e51eed91eea04d0f002472b8abcc2/src/jabs/resources/docs/user_guide/imgs/pose_overlay.png -------------------------------------------------------------------------------- /src/jabs/resources/docs/user_guide/imgs/selecting_frames.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumarLabJax/JABS-behavior-classifier/4749d378a18e51eed91eea04d0f002472b8abcc2/src/jabs/resources/docs/user_guide/imgs/selecting_frames.png -------------------------------------------------------------------------------- /src/jabs/resources/docs/user_guide/imgs/stacked_timeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumarLabJax/JABS-behavior-classifier/4749d378a18e51eed91eea04d0f002472b8abcc2/src/jabs/resources/docs/user_guide/imgs/stacked_timeline.png -------------------------------------------------------------------------------- /src/jabs/resources/docs/user_guide/imgs/timeline_menu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumarLabJax/JABS-behavior-classifier/4749d378a18e51eed91eea04d0f002472b8abcc2/src/jabs/resources/docs/user_guide/imgs/timeline_menu.png -------------------------------------------------------------------------------- /src/jabs/resources/docs/user_guide/imgs/track_overlay.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumarLabJax/JABS-behavior-classifier/4749d378a18e51eed91eea04d0f002472b8abcc2/src/jabs/resources/docs/user_guide/imgs/track_overlay.png -------------------------------------------------------------------------------- /src/jabs/resources/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumarLabJax/JABS-behavior-classifier/4749d378a18e51eed91eea04d0f002472b8abcc2/src/jabs/resources/icon.png -------------------------------------------------------------------------------- /src/jabs/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | """jabs scripts module""" 2 | 3 | from .gui_entrypoint import main 4 | 5 | __all__ = [ 6 | "main", 7 | ] 8 | -------------------------------------------------------------------------------- /src/jabs/scripts/generate_features.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """initialize JABS features for a pose file 4 | 5 | computes features if they do not exist 6 | """ 7 | 8 | import argparse 9 | import sys 10 | from pathlib import Path 11 | 12 | from jabs.feature_extraction.features import IdentityFeatures 13 | from jabs.pose_estimation import open_pose_file 14 | from jabs.project import Project 15 | from jabs.types import ProjectDistanceUnit 16 | 17 | 18 | def generate_feature_cache(args): 19 | """ 20 | Generate and cache features for each identity in a pose file. 21 | 22 | Loads the pose file, computes features for each identity using the specified 23 | settings, and caches the results in the given feature directory. 
Optionally
 24 |     generates window features if a window size is provided.
 25 | 
 26 |     Args:
 27 |         args: argparse Namespace object containing script arguments, including pose file path,
 28 |             pose version, feature directory, distance unit, window size, and fps.
 29 |     """
 30 |     distance_unit = (
 31 |         ProjectDistanceUnit.CM if args.cm_units else ProjectDistanceUnit.PIXEL
 32 |     )
 33 |     settings = Project.settings_by_pose_version(args.pose_version, distance_unit)
 34 |     if args.window_size is not None:
 35 |         settings["window_size"] = args.window_size
 36 |         cache_window = True
 37 |     else:
 38 |         cache_window = False
 39 | 
 40 |     pose_est = open_pose_file(args.pose_file)
 41 |     for curr_id in pose_est.identities:
 42 |         # Note: Features are still cached with the highest pose version.
 43 |         # It isn't until get_features is called that filtering occurs
 44 |         features = IdentityFeatures(
 45 |             args.pose_file,
 46 |             curr_id,
 47 |             args.feature_dir,
 48 |             pose_est,
 49 |             fps=args.fps,
 50 |             op_settings=settings,
 51 |             cache_window=cache_window,
 52 |         )
 53 |         # Window features are not automatically generated.
 54 |         if cache_window:
 55 |             _ = features.get_window_features(
 56 |                 settings["window_size"], settings["social"], force=True
 57 |             )
 58 | 
 59 | 
 60 | def main():
 61 |     """jabs-features"""
 62 |     parser = argparse.ArgumentParser(prog=f"{script_name()} features")
 63 |     parser.add_argument(
 64 |         "--pose-file",
 65 |         required=True,
 66 |         type=Path,
 67 |         help="pose file to compute features for",
 68 |     )
 69 |     parser.add_argument(
 70 |         "--pose-version",
 71 |         required=True,
 72 |         type=int,
 73 |         help="pose version to calculate features",
 74 |     )
 75 |     parser.add_argument(
 76 |         "--feature-dir",
 77 |         required=True,
 78 |         type=Path,
 79 |         help="directory to write output features",
 80 |     )
 81 |     parser.add_argument(
 82 |         "--use-cm-distances",
 83 |         action="store_true",
 84 |         dest="cm_units",
 85 |         default=False,
 86 |         help="use cm distance units instead of pixel",
 87 |     )
 88 |     parser.add_argument(
 89 |         "--window-size",
 90 |         type=int,
 91 |         default=None,
 92 |         help="window size for features (default none)",
 93 |     )
 94 |     parser.add_argument(
 95 |         "--fps", type=int, default=30, help="frames per second to use for feature calculation"
 96 |     )
 97 |     args = parser.parse_args()
 98 | 
 99 |     generate_feature_cache(args)
100 | 
101 | 
102 | def script_name() -> str:
103 |     """return the script name"""
104 |     return Path(sys.argv[0]).name
105 | 
106 | 
107 | if __name__ == "__main__":
108 |     main()
109 | 
--------------------------------------------------------------------------------
/src/jabs/scripts/gui_entrypoint.py:
--------------------------------------------------------------------------------
  1 | import argparse
  2 | import sys
  3 | 
  4 | from PySide6 import QtWidgets
  5 | from PySide6.QtGui import QIcon
  6 | 
  7 | from jabs.constants import APP_NAME, APP_NAME_LONG
  8 | from jabs.resources import ICON_PATH
  9 | from jabs.ui import MainWindow
 10 | 
 11 | 
 12 | def main():
 13 |     """main entrypoint for JABS video labeling and classifier GUI
 14 | 
 15 |     takes one optional positional argument: path to project directory
 16 |     """
 17 |     app = QtWidgets.QApplication(sys.argv)
 18 |     app.setWindowIcon(QIcon(str(ICON_PATH)))
 19 | 
 20 |     parser = argparse.ArgumentParser()
 21 |     parser.add_argument("project_dir", nargs="?")
 22 |     args = parser.parse_args()
 23 | 
 24 |     main_window = MainWindow(app_name=APP_NAME, app_name_long=APP_NAME_LONG)
 25 |     main_window.show()
 26 |     if main_window.show_license_dialog() != QtWidgets.QDialog.DialogCode.Accepted:
 27 |         sys.exit(1)
 28 | 
 29 |     if args.project_dir is not None:
 30 |         try:
 31 | 
main_window.open_project(args.project_dir) 32 | except Exception as e: 33 | sys.exit(f"Error opening project: {e}") 34 | 35 | # user accepted license terms, run the main application loop 36 | sys.exit(app.exec()) 37 | 38 | 39 | if __name__ == "__main__": 40 | main() 41 | -------------------------------------------------------------------------------- /src/jabs/scripts/stats.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | 4 | import numpy as np 5 | from tabulate import tabulate 6 | 7 | from jabs.classifier import Classifier 8 | from jabs.project import load_training_data 9 | from jabs.scripts.classify import train 10 | from jabs.types import ProjectDistanceUnit 11 | 12 | 13 | def main(): 14 | """jabs-stats""" 15 | parser = argparse.ArgumentParser( 16 | prog="stats", 17 | description="print accuracy statistics for the given classifier", 18 | ) 19 | 20 | parser.add_argument( 21 | "-k", 22 | help="the parameter controlling the maximum number of iterations." 23 | " Default is to iterate over all leave-one-out possibilities.", 24 | type=int, 25 | ) 26 | 27 | parser.add_argument( 28 | "training", 29 | help="training data HDF5 file", 30 | ) 31 | 32 | args = parser.parse_args() 33 | 34 | classifier = train(args.training) 35 | 36 | features, group_mapping = load_training_data(args.training) 37 | data_generator = Classifier.leave_one_group_out( 38 | features["per_frame"], 39 | features["window"], 40 | features["labels"], 41 | features["groups"], 42 | ) 43 | 44 | table_rows = [] 45 | accuracies = [] 46 | fbeta_behavior = [] 47 | fbeta_notbehavior = [] 48 | 49 | iter_count = 0 50 | for i, data in enumerate(itertools.islice(data_generator, args.k)): 51 | iter_count += 1 52 | 53 | test_info = group_mapping[data["test_group"]] 54 | 55 | # train classifier, and then use it to classify our test data 56 | classifier.train(data) 57 | predictions = classifier.predict(data["test_data"]) 58 | 59 | # calculate some performance metrics using the classifications of 60 | # the test data 61 | accuracy = classifier.accuracy_score(data["test_labels"], predictions) 62 | pr = classifier.precision_recall_score(data["test_labels"], predictions) 63 | confusion = classifier.confusion_matrix(data["test_labels"], predictions) 64 | 65 | table_rows.append( 66 | [ 67 | accuracy, 68 | pr[0][0], 69 | pr[0][1], 70 | pr[1][0], 71 | pr[1][1], 72 | pr[2][0], 73 | pr[2][1], 74 | f"{test_info['video']} [{test_info['identity']}]", 75 | ] 76 | ) 77 | accuracies.append(accuracy) 78 | fbeta_behavior.append(pr[2][1]) 79 | fbeta_notbehavior.append(pr[2][0]) 80 | 81 | # print performance metrics and feature importance to console 82 | print("-" * 70) 83 | print(f"training iteration {i}") 84 | print("TEST DATA:") 85 | print(f"\tVideo: {test_info['video']}") 86 | print(f"\tIdentity: {test_info['identity']}") 87 | print(f"ACCURACY: {accuracy * 100:.2f}%") 88 | print("PRECISION RECALL:") 89 | print(f" {'not behavior':12} behavior") 90 | print(f" precision {pr[0][0]:<12.8} {pr[0][1]:<.8}") 91 | print(f" recall {pr[1][0]:<12.8} {pr[1][1]:<.8}") 92 | print(f" fbeta score {pr[2][0]:<12.8} {pr[2][1]:<.8}") 93 | print(f" support {pr[3][0]:<12} {pr[3][1]}") 94 | print("CONFUSION MATRIX:") 95 | print(f"{confusion}") 96 | print("-" * 70) 97 | 98 | print("Top 10 features by importance:") 99 | classifier.print_feature_importance(data["feature_names"], 10) 100 | 101 | if iter_count >= 1: 102 | print("\n" + "=" * 70) 103 | print("SUMMARY\n") 104 | print( 105 | tabulate( 106 | 
table_rows, 107 | showindex="always", 108 | headers=[ 109 | "accuracy", 110 | "precision\n(not behavior)", 111 | "precision\n(behavior)", 112 | "recall\n(not behavior)", 113 | "recall\n(behavior)", 114 | "f beta score\n(not behavior)", 115 | "f beta score\n(behavior)", 116 | "test - leave one out:\n(video [identity])", 117 | ], 118 | ) 119 | ) 120 | 121 | print(f"\nmean accuracy: {np.mean(accuracies):.5}") 122 | print(f"mean fbeta score (behavior): {np.mean(fbeta_behavior):.5}") 123 | print(f"mean fbeta score (not behavior): {np.mean(fbeta_notbehavior):.5}") 124 | print(f"\nClassifier: {classifier.classifier_name}") 125 | print(f"Behavior: {features['behavior']}") 126 | unit = ( 127 | "cm" 128 | if classifier.project_settings["cm_units"] == ProjectDistanceUnit.CM 129 | else "pixel" 130 | ) 131 | print(f"Feature Distance Unit: {unit}") 132 | print("-" * 70) 133 | else: 134 | print("No results calculated") 135 | 136 | 137 | if __name__ == "__main__": 138 | main() 139 | -------------------------------------------------------------------------------- /src/jabs/types/__init__.py: -------------------------------------------------------------------------------- 1 | """Module for defining enums used in JABS""" 2 | 3 | from .classifier_types import ClassifierType 4 | from .units import ProjectDistanceUnit 5 | -------------------------------------------------------------------------------- /src/jabs/types/classifier_types.py: -------------------------------------------------------------------------------- 1 | from enum import IntEnum 2 | 3 | 4 | class ClassifierType(IntEnum): 5 | """Classifier type for the project.""" 6 | 7 | RANDOM_FOREST = 1 8 | GRADIENT_BOOSTING = 2 9 | XGBOOST = 3 10 | -------------------------------------------------------------------------------- /src/jabs/types/units.py: -------------------------------------------------------------------------------- 1 | import enum 2 | 3 | 4 | class ProjectDistanceUnit(enum.IntEnum): 5 | """Distance unit for the project.""" 6 | 7 | PIXEL = 0 8 | CM = 1 9 | -------------------------------------------------------------------------------- /src/jabs/ui/__init__.py: -------------------------------------------------------------------------------- 1 | """JABS UI Module""" 2 | 3 | from .main_window import MainWindow 4 | 5 | __all__ = ["MainWindow"] 6 | -------------------------------------------------------------------------------- /src/jabs/ui/about_dialog.py: -------------------------------------------------------------------------------- 1 | from PySide6.QtCore import Qt 2 | from PySide6.QtWidgets import QDialog, QLabel, QPushButton, QVBoxLayout 3 | 4 | from jabs.version import version_str 5 | 6 | 7 | class AboutDialog(QDialog): 8 | """dialog that shows application info such as version and copyright""" 9 | 10 | def __init__(self, app_name, *args, **kwargs): 11 | super().__init__(*args, **kwargs) 12 | self.setWindowTitle(f"About {app_name}") 13 | 14 | layout = QVBoxLayout() 15 | 16 | layout.addWidget( 17 | QLabel(f"Version: {version_str()}"), alignment=Qt.AlignmentFlag.AlignCenter 18 | ) 19 | 20 | label = QLabel( 21 | f"{app_name} developed by the " 22 | "Kumar Lab " 23 | "at The Jackson Laboratory" 24 | ) 25 | label.setOpenExternalLinks(True) 26 | layout.addWidget(label, alignment=Qt.AlignmentFlag.AlignCenter) 27 | 28 | layout.addWidget( 29 | QLabel("Copyright 2025 The Jackson Laboratory. 
All Rights Reserved"), 30 | alignment=Qt.AlignmentFlag.AlignCenter, 31 | ) 32 | 33 | email_label = QLabel("jabs@jax.org") 34 | email_label.setOpenExternalLinks(True) 35 | layout.addWidget(email_label, alignment=Qt.AlignmentFlag.AlignCenter) 36 | 37 | ok_button = QPushButton("OK") 38 | ok_button.clicked.connect(self.close) 39 | 40 | layout.addWidget(ok_button, alignment=Qt.AlignmentFlag.AlignLeft) 41 | 42 | self.setLayout(layout) 43 | -------------------------------------------------------------------------------- /src/jabs/ui/archive_behavior_dialog.py: -------------------------------------------------------------------------------- 1 | from PySide6 import QtCore 2 | from PySide6.QtWidgets import ( 3 | QCheckBox, 4 | QComboBox, 5 | QDialog, 6 | QHBoxLayout, 7 | QPushButton, 8 | QVBoxLayout, 9 | ) 10 | 11 | 12 | class ArchiveBehaviorDialog(QDialog): 13 | """dialog to allow a user to select a behavior to archive from the project""" 14 | 15 | behavior_archived = QtCore.Signal(str) 16 | 17 | def __init__(self, behaviors: [str], *args, **kwargs): 18 | super().__init__(*args, **kwargs) 19 | 20 | self._behavior_selection = QComboBox() 21 | self._behavior_selection.addItems(behaviors) 22 | self._behavior_selection.currentIndexChanged.connect( 23 | self.__behavior_selection_changed 24 | ) 25 | 26 | self._confirm = QCheckBox("Confirm", self) 27 | self._confirm.setChecked(False) 28 | self._confirm.stateChanged.connect(self.__confirm_checkbox_changed) 29 | 30 | self._archive_button = QPushButton("Archive") 31 | self._archive_button.setEnabled(False) 32 | self._archive_button.clicked.connect(self.__archive) 33 | cancel_button = QPushButton("Close") 34 | cancel_button.clicked.connect(self.close) 35 | 36 | button_layout = QHBoxLayout() 37 | button_layout.addWidget(cancel_button) 38 | button_layout.addWidget(self._archive_button) 39 | 40 | layout = QVBoxLayout() 41 | layout.addWidget(self._behavior_selection) 42 | layout.addWidget(self._confirm) 43 | layout.addLayout(button_layout) 44 | 45 | self.setLayout(layout) 46 | 47 | def __confirm_checkbox_changed(self, state: bool): 48 | self._archive_button.setEnabled(state) 49 | 50 | def __behavior_selection_changed(self): 51 | self._confirm.setChecked(False) 52 | 53 | def __archive(self): 54 | # remove selected behavior from combo box 55 | behavior = self._behavior_selection.currentText() 56 | self.__remove_behavior(behavior) 57 | 58 | # if there are no other behaviors that can be archived then hide the dialog 59 | if self._behavior_selection.count() == 0: 60 | self.hide() 61 | 62 | # emit the signal to handle archiving the behavior 63 | self.behavior_archived.emit(behavior) 64 | 65 | # after emitting the signal, we can close the dialog if there are no more behaviors in the drop-down 66 | if self._behavior_selection.count() == 0: 67 | self.done(1) 68 | 69 | def __remove_behavior(self, behavior: str): 70 | idx = self._behavior_selection.findText(behavior, QtCore.Qt.MatchExactly) 71 | if idx != -1: 72 | self._behavior_selection.removeItem(idx) 73 | -------------------------------------------------------------------------------- /src/jabs/ui/classification_thread.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from PySide6 import QtCore 4 | 5 | from jabs.feature_extraction import DEFAULT_WINDOW_SIZE, IdentityFeatures 6 | from jabs.video_reader.utilities import get_fps 7 | 8 | 9 | class ClassifyThread(QtCore.QThread): 10 | """thread to run the classification to keep the main GUI 
thread responsive""" 11 | 12 | done = QtCore.Signal(dict) 13 | update_progress = QtCore.Signal(int) 14 | current_status = QtCore.Signal(str) 15 | 16 | def __init__(self, classifier, project, behavior, current_video, parent=None): 17 | super().__init__(parent=parent) 18 | self._classifier = classifier 19 | self._project = project 20 | self._behavior = behavior 21 | self._tasks_complete = 0 22 | self._current_video = current_video 23 | 24 | def run(self): 25 | """thread's main function. 26 | 27 | runs the classifier for each identity in each video 28 | """ 29 | self._tasks_complete = 0 30 | 31 | predictions = {} 32 | probabilities = {} 33 | frame_indexes = {} 34 | 35 | project_settings = self._project.settings_manager.get_behavior(self._behavior) 36 | 37 | # iterate over each video in the project 38 | for video in self._project.video_manager.videos: 39 | video_path = self._project.video_manager.video_path(video) 40 | 41 | # load the poses for this video 42 | pose_est = self._project.load_pose_est(video_path) 43 | # fps used to scale some features from per pixel time unit to 44 | # per second 45 | fps = get_fps(str(video_path)) 46 | 47 | # make predictions for each identity in this video 48 | predictions[video] = {} 49 | probabilities[video] = {} 50 | frame_indexes[video] = {} 51 | 52 | for identity in pose_est.identities: 53 | self.current_status.emit(f"Classifying {video}, Identity {identity}") 54 | 55 | # get the features for this identity 56 | features = IdentityFeatures( 57 | video, 58 | identity, 59 | self._project.feature_dir, 60 | pose_est, 61 | fps=fps, 62 | op_settings=project_settings, 63 | ) 64 | feature_values = features.get_features( 65 | project_settings.get("window_size", DEFAULT_WINDOW_SIZE) 66 | ) 67 | 68 | # reformat the data in a single 2D numpy array to pass 69 | # to the classifier 70 | per_frame_features = pd.DataFrame( 71 | IdentityFeatures.merge_per_frame_features( 72 | feature_values["per_frame"] 73 | ) 74 | ) 75 | window_features = pd.DataFrame( 76 | IdentityFeatures.merge_window_features(feature_values["window"]) 77 | ) 78 | data = self._classifier.combine_data( 79 | per_frame_features, window_features 80 | ) 81 | 82 | if data.shape[0] > 0: 83 | # make predictions 84 | predictions[video][identity] = self._classifier.predict(data) 85 | 86 | # also get the probabilities 87 | prob = self._classifier.predict_proba(data) 88 | # Save the probability for the predicted class only. 89 | # The following code uses some 90 | # numpy magic to use the _predictions array as column indexes 91 | # for each row of the 'prob' array we just computed. 
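# e.g. prob = [[0.2, 0.8], [0.9, 0.1]] with predictions [1, 0] selects prob[[0, 1], [1, 0]] -> [0.8, 0.9].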
92 | probabilities[video][identity] = prob[ 93 | np.arange(len(prob)), predictions[video][identity] 94 | ] 95 | 96 | # save the indexes for the predicted frames 97 | frame_indexes[video][identity] = feature_values["frame_indexes"] 98 | else: 99 | predictions[video][identity] = np.array(0) 100 | probabilities[video][identity] = np.array(0) 101 | frame_indexes[video][identity] = np.array(0) 102 | self._tasks_complete += 1 103 | self.update_progress.emit(self._tasks_complete) 104 | 105 | # save predictions 106 | self.current_status.emit("Saving Predictions") 107 | self._project.save_predictions( 108 | predictions, probabilities, frame_indexes, self._behavior, self._classifier 109 | ) 110 | 111 | self._tasks_complete += 1 112 | self.update_progress.emit(self._tasks_complete) 113 | self.done.emit( 114 | { 115 | "predictions": predictions[self._current_video], 116 | "probabilities": probabilities[self._current_video], 117 | "frame_indexes": frame_indexes[self._current_video], 118 | } 119 | ) 120 | -------------------------------------------------------------------------------- /src/jabs/ui/colors.py: -------------------------------------------------------------------------------- 1 | POSITION_MARKER_COLOR = (231, 66, 126) 2 | SELECTION_COLOR = (255, 255, 0) 3 | 4 | BACKGROUND_COLOR = (128, 128, 128, 255) 5 | NOT_BEHAVIOR_COLOR = (0, 86, 229, 255) 6 | BEHAVIOR_COLOR = (255, 159, 0, 255) 7 | -------------------------------------------------------------------------------- /src/jabs/ui/k_fold_slider_widget.py: -------------------------------------------------------------------------------- 1 | from PySide6 import QtWidgets 2 | from PySide6.QtCore import Qt, Signal 3 | 4 | 5 | class KFoldSliderWidget(QtWidgets.QWidget): 6 | """widget to allow user to select k parameter for k-fold cross validation 7 | 8 | basically consists of a QSlider and three QLabel widgets with 9 | no spacing/margins 10 | """ 11 | 12 | valueChanged = Signal(int) 13 | 14 | def __init__(self, kmax=10, *args, **kwargs): 15 | super().__init__(*args, **kwargs) 16 | 17 | self._slider = QtWidgets.QSlider(Qt.Orientation.Horizontal) 18 | self._slider.setMinimum(0) 19 | self._slider.setMaximum(kmax) 20 | self._slider.setTickInterval(1) 21 | self._slider.setValue(1) 22 | self._slider.setTickPosition(QtWidgets.QSlider.TickPosition.TicksBelow) 23 | self._slider.valueChanged.connect(self.valueChanged) 24 | self._slider.setFocusPolicy(Qt.FocusPolicy.NoFocus) 25 | 26 | # slider range labels 27 | label_min = QtWidgets.QLabel("0") 28 | label_min.setAlignment(Qt.AlignmentFlag.AlignLeft) 29 | label_max = QtWidgets.QLabel(f"{kmax}") 30 | label_max.setAlignment(Qt.AlignmentFlag.AlignRight) 31 | 32 | slider_vbox = QtWidgets.QVBoxLayout() 33 | slider_hbox = QtWidgets.QHBoxLayout() 34 | slider_hbox.setContentsMargins(0, 0, 0, 0) 35 | slider_vbox.setContentsMargins(0, 0, 0, 0) 36 | slider_vbox.setSpacing(0) 37 | slider_vbox.addWidget(QtWidgets.QLabel("Cross Validation k:")) 38 | slider_vbox.addWidget(self._slider) 39 | slider_vbox.addLayout(slider_hbox) 40 | slider_hbox.addWidget(label_min, Qt.AlignmentFlag.AlignLeft) 41 | slider_hbox.addWidget(label_max, Qt.AlignmentFlag.AlignRight) 42 | 43 | self.setLayout(slider_vbox) 44 | 45 | def value(self): 46 | """return the slider value""" 47 | return self._slider.value() 48 | -------------------------------------------------------------------------------- /src/jabs/ui/license_dialog.py: -------------------------------------------------------------------------------- 1 | from PySide6.QtCore import Qt 2 | from 
PySide6.QtWidgets import QDialog, QHBoxLayout, QLabel, QPushButton, QVBoxLayout 3 | 4 | from ..constants import APP_NAME, APP_NAME_LONG 5 | 6 | 7 | class LicenseAgreementDialog(QDialog): 8 | """Dialog for accepting the application license agreement. 9 | 10 | Presents the user with a message to accept or reject the license terms for the application. 11 | Provides YES and NO buttons to confirm or decline the agreement. 12 | 13 | Args: 14 | *args: Additional positional arguments for QDialog. 15 | **kwargs: Additional keyword arguments for QDialog. 16 | """ 17 | 18 | def __init__(self, *args, **kwargs): 19 | super().__init__(*args, **kwargs) 20 | self.setWindowTitle(f"Accept {APP_NAME_LONG} License") 21 | self.setModal(True) 22 | 23 | layout = QVBoxLayout() 24 | 25 | layout.addWidget( 26 | QLabel(f"I have read and I agree to the {APP_NAME} license terms."), 27 | alignment=Qt.AlignmentFlag.AlignCenter, 28 | ) 29 | 30 | button_layout = QHBoxLayout() 31 | 32 | yes_button = QPushButton("YES") 33 | yes_button.clicked.connect(self.accept) 34 | 35 | no_button = QPushButton("NO") 36 | no_button.clicked.connect(self.reject) 37 | 38 | button_layout.addStretch() 39 | button_layout.addWidget(yes_button, alignment=Qt.AlignmentFlag.AlignRight) 40 | button_layout.addWidget(no_button, alignment=Qt.AlignmentFlag.AlignRight) 41 | 42 | layout.addLayout(button_layout) 43 | 44 | self.setLayout(layout) 45 | -------------------------------------------------------------------------------- /src/jabs/ui/player_widget/__init__.py: -------------------------------------------------------------------------------- 1 | """Player widget module.""" 2 | 3 | from .player_widget import PlayerWidget 4 | 5 | __all__ = [ 6 | "PlayerWidget", 7 | ] 8 | -------------------------------------------------------------------------------- /src/jabs/ui/project_loader_thread.py: -------------------------------------------------------------------------------- 1 | from PySide6.QtCore import QThread, Signal, SignalInstance 2 | 3 | from jabs.project import Project 4 | 5 | 6 | class ProjectLoaderThread(QThread): 7 | """JABS Project Loader Thread 8 | 9 | This thread is used to load a JABS project in the background so that the main 10 | GUI thread remains responsive. It emits signals when the project is loaded. 11 | """ 12 | 13 | project_loaded: SignalInstance = Signal() 14 | load_error: SignalInstance = Signal(Exception) 15 | 16 | def __init__(self, project_path: str, parent=None): 17 | super().__init__(parent) 18 | self._project_path = project_path 19 | self._project = None 20 | 21 | def run(self): 22 | """Run the thread.""" 23 | # Open the project, this can take a while 24 | try: 25 | self._project = Project(self._project_path) 26 | self.project_loaded.emit() 27 | except Exception as e: 28 | # if there was an exception, we'll emit the Exception as a signal so that 29 | # the main GUI thread can handle it 30 | self.load_error.emit(e) 31 | 32 | @property 33 | def project(self): 34 | """Return the loaded project.""" 35 | return self._project 36 | -------------------------------------------------------------------------------- /src/jabs/ui/stacked_timeline_widget/__init__.py: -------------------------------------------------------------------------------- 1 | """Package providing the StackedTimelineWidget for visualizing label and prediction timelines. 
2 | 3 | This package includes: 4 | - StackedTimelineWidget: A widget that displays label and prediction overviews for multiple identities, 5 | allowing users to view labels and predictions in a stacked timeline format. Also used for indicating 6 | which frames are currently selected for labeling. 7 | 8 | The widget supports switching between different identities, selection modes, and view modes for efficient 9 | annotation and review. 10 | """ 11 | 12 | from .stacked_timeline_widget import StackedTimelineWidget 13 | -------------------------------------------------------------------------------- /src/jabs/ui/stacked_timeline_widget/frame_labels_widget.py: -------------------------------------------------------------------------------- 1 | from PySide6.QtCore import QSize, Qt 2 | from PySide6.QtGui import QFont, QFontMetrics, QPainter 3 | from PySide6.QtWidgets import QApplication, QSizePolicy, QWidget 4 | 5 | 6 | class FrameLabelsWidget(QWidget): 7 | """Widget for drawing frame ticks and labels below a LabelOverviewWidget. 8 | 9 | Displays tick marks and frame numbers for a sliding window of frames centered around the current frame. 10 | Intended to visually indicate frame positions and intervals in a video labeling interface. 11 | 12 | Args: 13 | *args: Additional positional arguments for QWidget. 14 | **kwargs: Additional keyword arguments for QWidget. 15 | """ 16 | 17 | def __init__(self, *args, **kwargs): 18 | super().__init__(*args, **kwargs) 19 | 20 | # allow widget to expand horizontally but maintain fixed vertical size 21 | self.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Fixed) 22 | 23 | # number of frames on each side of current frame to include in 24 | # sliding window 25 | # this needs to match what is set in ManualLabelsWidget, so once 26 | # we make this configurable, it needs to get set in both locations 27 | self._window_size = 100 28 | 29 | # total number of frames being displayed in sliding window 30 | self._nframes = self._window_size * 2 + 1 31 | 32 | # number of frames between ticks/labels 33 | self._tick_interval = 50 34 | 35 | # current position 36 | self._current_frame = 0 37 | 38 | # information about the video needed to properly render widget 39 | self._num_frames = 0 40 | 41 | # size each frame takes up in the bar in pixels 42 | self._frame_width = self.size().width() // self._nframes 43 | self._adjusted_width = self._nframes * self._frame_width 44 | self._offset = (self.size().width() - self._adjusted_width) / 2 45 | 46 | self._font = QFont("Arial", 12) 47 | self._font_metrics = QFontMetrics(self._font) 48 | self._font_height = self._font_metrics.height() 49 | 50 | def sizeHint(self): 51 | """Give an initial starting size. 52 | 53 | Width hint is not so important because we allow the widget to resize 54 | horizontally to fill the available container. The height is fixed, 55 | so the value used here sets the height of the widget. 56 | """ 57 | return QSize(400, self._font_height + 10) 58 | 59 | def resizeEvent(self, event): 60 | """handle resize events""" 61 | self._frame_width = self.size().width() // self._nframes 62 | self._adjusted_width = self._nframes * self._frame_width 63 | self._offset = (self.size().width() - self._adjusted_width) / 2 64 | 65 | def paintEvent(self, event): 66 | """override QWidget paintEvent 67 | 68 | This draws the widget. 
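Ticks are drawn every `_tick_interval` frames within the sliding window, with the corresponding frame number rendered beneath each tick.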
69 | """ 70 | if self._num_frames == 0: 71 | return 72 | 73 | # starting and ending frames of the current view 74 | start = self._current_frame - self._window_size 75 | end = self._current_frame + self._window_size 76 | 77 | qp = QPainter(self) 78 | # make the ticks the same color as the text 79 | qp.setBrush(QApplication.palette().text().color()) 80 | qp.setFont(self._font) 81 | self._draw_ticks(qp, start, end) 82 | qp.end() 83 | 84 | def _draw_ticks(self, painter, start, end): 85 | """draw ticks at the proper interval and draw the frame number under the tick 86 | 87 | Args: 88 | painter: active QPainter 89 | start: starting frame number 90 | end: ending frame number 91 | """ 92 | for i in range(start, end + 1): 93 | if (0 <= i <= self._num_frames) and i % self._tick_interval == 0: 94 | offset = self._offset + ((i - start + 0.5) * self._frame_width) - 1 95 | painter.setPen(Qt.PenStyle.NoPen) 96 | painter.drawRect(offset, 0, 2, 8) 97 | 98 | label_text = f"{i}" 99 | label_width = self._font_metrics.horizontalAdvance(label_text) 100 | painter.setPen(QApplication.palette().text().color()) 101 | painter.drawText( 102 | offset - label_width / 2 + 1, self._font_height + 8, label_text 103 | ) 104 | 105 | def set_current_frame(self, current_frame): 106 | """called to reposition the view around new current frame""" 107 | self._current_frame = current_frame 108 | self.update() 109 | 110 | def set_num_frames(self, num_frames): 111 | """set number of frames in current video 112 | 113 | this is used to keep from drawing ticks past the end of the video 114 | """ 115 | self._num_frames = num_frames 116 | -------------------------------------------------------------------------------- /src/jabs/ui/stacked_timeline_widget/label_overview_widget/__init__.py: -------------------------------------------------------------------------------- 1 | """Package providing widgets for label and prediction overview in the user interface. 2 | 3 | This package includes: 4 | - LabelOverviewWidget: Widget for displaying labels. 5 | - PredictionOverviewWidget: Widget for displaying prediction data. 6 | 7 | These widgets are used to visualize labels and predictions for different identities within the application. While the 8 | user is labeling, the LabelOverviewWidget is used to indicate which frames are currently selected for labeling. The 9 | PredictionOverviewWidget is used to display the predictions made by the model. 10 | """ 11 | 12 | from .label_overview_widget import LabelOverviewWidget 13 | from .prediction_overview_widget import PredictionOverviewWidget 14 | 15 | __all__ = [ 16 | "LabelOverviewWidget", 17 | "PredictionOverviewWidget", 18 | ] 19 | -------------------------------------------------------------------------------- /src/jabs/ui/stacked_timeline_widget/label_overview_widget/prediction_overview_widget.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .label_overview_widget import LabelOverviewWidget 4 | from .predicted_label_widget import PredictedLabelWidget 5 | from .timeline_prediction_widget import TimelinePredictionWidget 6 | 7 | 8 | class PredictionOverviewWidget(LabelOverviewWidget): 9 | """Widget that displays an overview of predicted labels and global inference results for a video. 10 | 11 | This widget replaces the manual label and timeline widgets of LabelOverviewWidget with 12 | widgets specialized for visualizing model predictions. It provides methods to set 13 | prediction data and disables setting manual labels. 
14 | """ 15 | 16 | @classmethod 17 | def _timeline_widget_factory(cls, parent): 18 | return TimelinePredictionWidget(parent) 19 | 20 | @classmethod 21 | def _label_widget_factory(cls, parent): 22 | return PredictedLabelWidget(parent) 23 | 24 | def set_labels(self, labels: np.ndarray, probabilities: np.ndarray): 25 | """set prediction data to display 26 | 27 | overrides the set_labels method of LabelOverviewWidget to set predictions instead of manual labels. 28 | 29 | Args: 30 | labels (np.ndarray): Array of predicted labels. 31 | probabilities (np.ndarray): Array of prediction probabilities corresponding to the labels. 32 | """ 33 | self._label_widget.set_labels(labels, probabilities) 34 | self._timeline_widget.set_labels(labels) 35 | self.update_labels() 36 | 37 | def reset(self): 38 | """Reset the widget to its initial state.""" 39 | self._timeline_widget.reset() 40 | self._label_widget.set_labels(None, None) 41 | self._num_frames = 0 42 | -------------------------------------------------------------------------------- /src/jabs/ui/stacked_timeline_widget/label_overview_widget/timeline_prediction_widget.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PySide6.QtCore import Qt 3 | from PySide6.QtGui import QImage, QPainter, QPixmap 4 | 5 | from jabs.project import TrackLabels 6 | 7 | from .timeline_label_widget import TimelineLabelWidget 8 | 9 | 10 | class TimelinePredictionWidget(TimelineLabelWidget): 11 | """TimelinePredictionWidget 12 | 13 | subclass of TimelineLabelWidget with some modifications: 14 | self._labels will be a numpy array and not a TrackLabels object 15 | also uses opacity to indicate the confidence of the label 16 | 17 | Args: 18 | *args: Additional positional arguments for QWidget. 19 | **kwargs: Additional keyword arguments for QWidget. 20 | """ 21 | 22 | def __init__(self, *args, **kwargs) -> None: 23 | super().__init__(*args, **kwargs) 24 | 25 | def _update_bar(self) -> None: 26 | """Update the timeline bar pixmap with downsampled label colors. 27 | 28 | Overrides _update_bar() from parent class to use a np.ndarray as input instead of a TrackLabels object. 29 | Downsamples the label array to match the current pixmap width, maps labels to RGBA colors, and renders 30 | the color bar as a QPixmap for display. 
31 | """ 32 | if self._labels is None: 33 | return 34 | 35 | width = self.size().width() 36 | height = self.size().height() 37 | 38 | # create a pixmap with a width that evenly divides the total number of 39 | # frames so that each pixel along the width represents a bin of frames 40 | # (_update_scale() has done this, we can use pixmap_offset to figure 41 | # out how many pixels of padding will be on each side of the final 42 | # pixmap) 43 | pixmap_width = width - 2 * self._pixmap_offset 44 | 45 | self._pixmap = QPixmap(pixmap_width, height) 46 | self._pixmap.fill(Qt.GlobalColor.transparent) 47 | 48 | downsampled = TrackLabels.downsample(self._labels, pixmap_width) 49 | 50 | # use downsampled labels to generate RGBA colors 51 | # labels are -1, 0, 1, 2 so add 1 to the downsampled labels to convert to indices in color_lut 52 | colors = self.COLOR_LUT[downsampled + 1] # shape (width, 4) 53 | color_bar = np.repeat( 54 | colors[np.newaxis, :, :], self._bar_height, axis=0 55 | ) # shape (bar_height, width, 4) 56 | 57 | img = QImage( 58 | color_bar.data, 59 | color_bar.shape[1], 60 | color_bar.shape[0], 61 | QImage.Format.Format_RGBA8888, 62 | ) 63 | painter = QPainter(self._pixmap) 64 | painter.drawImage(0, self._bar_padding, img) 65 | painter.end() 66 | -------------------------------------------------------------------------------- /src/jabs/ui/user_guide_viewer_widget/__init__.py: -------------------------------------------------------------------------------- 1 | """User guide viewer widget module.""" 2 | 3 | from .user_guide_dialog import UserGuideDialog 4 | 5 | __all__ = ["UserGuideDialog"] 6 | -------------------------------------------------------------------------------- /src/jabs/ui/user_guide_viewer_widget/user_guide_dialog.py: -------------------------------------------------------------------------------- 1 | import markdown2 2 | from PySide6.QtCore import Qt, QUrl 3 | from PySide6.QtWebEngineWidgets import QWebEngineView 4 | from PySide6.QtWidgets import QDialog, QPushButton, QVBoxLayout 5 | 6 | from jabs.resources import DOCS_DIR 7 | 8 | 9 | class UserGuideDialog(QDialog): 10 | """dialog that shows html rendering of user guide""" 11 | 12 | def __init__(self, app_name: str, *args, **kwargs): 13 | super().__init__(*args, **kwargs) 14 | self.setWindowTitle(f"{app_name} User Guide") 15 | self.resize(1000, 600) 16 | self._web_engine_view = QWebEngineView() 17 | self._load_content() 18 | 19 | layout = QVBoxLayout() 20 | 21 | layout.addWidget(self._web_engine_view) 22 | 23 | close_button = QPushButton("CLOSE") 24 | close_button.clicked.connect(self.close) 25 | layout.addWidget(close_button, alignment=Qt.AlignmentFlag.AlignLeft) 26 | 27 | self.setLayout(layout) 28 | 29 | def _load_content(self): 30 | user_guide_path = DOCS_DIR / "user_guide" / "user_guide.md" 31 | 32 | # need to specify a base URL when displaying the html content due to 33 | # the relative img urls in the user_guide.md document 34 | base_url = QUrl(f"{user_guide_path.parent.as_uri()}/") 35 | 36 | def error_html(message): 37 | return f""" 38 | 39 |
40 |             <html>
41 |             <head>
42 |             <meta charset="utf-8"/>
43 |             <title>Error</title>
44 |             <style>
45 |             body {{
46 |                 font-family: sans-serif;
47 |                 margin: 2em;
48 |                 color: #333333;
49 |             }}
50 |             h1 {{
51 |                 font-size: 1.2em;
52 |             }}
53 |             p {{
54 |                 margin-top: 1em;
55 |             }}
56 |             </style>
57 |             </head>
58 |             <body>
59 |             <h1>Unable to display user guide</h1>
60 |             <p>{message}</p>
61 |             </body>
62 |             </html>
63 |             """
64 | 
65 |         # convert the markdown user guide document to html for display,
66 |         # falling back to a simple error page if the file cannot be read
67 |         try:
68 |             with open(user_guide_path, encoding="utf-8") as f:
69 |                 markdown_content = f.read()
70 |             html = markdown2.markdown(markdown_content)
71 |         except FileNotFoundError as e:
72 |             html = error_html(
73 |                 f"Unable to find user guide: {e}"
74 |             )
75 |         except UnicodeDecodeError as e:
76 |             html = error_html(
77 |                 f"Unable to decode user guide: {e}
" 78 | ) 79 | 80 | self._web_engine_view.setHtml(html, baseUrl=base_url) 81 | -------------------------------------------------------------------------------- /src/jabs/ui/video_list_widget.py: -------------------------------------------------------------------------------- 1 | from PySide6 import QtCore, QtWidgets 2 | 3 | 4 | class _VideoListWidget(QtWidgets.QListWidget): 5 | """QListView that has been modified to not allow deselecting current selection without selecting a new row""" 6 | 7 | def __init__(self, *args, **kwargs): 8 | super().__init__(*args, **kwargs) 9 | self.setSelectionMode(QtWidgets.QAbstractItemView.SelectionMode.SingleSelection) 10 | self.setSortingEnabled(True) 11 | self.setEditTriggers(QtWidgets.QAbstractItemView.EditTrigger.NoEditTriggers) 12 | 13 | # don't take focus otherwise up/down arrows will change video 14 | # when the user is intending to skip forward/back frames 15 | self.setFocusPolicy(QtCore.Qt.FocusPolicy.NoFocus) 16 | 17 | def selectionCommand(self, index, event=None): 18 | """Override to prevent deselection of the current row.""" 19 | if self.selectedIndexes() and self.selectedIndexes()[0].row() == index.row(): 20 | return QtCore.QItemSelectionModel.SelectionFlag.NoUpdate 21 | return super().selectionCommand(index, event) 22 | 23 | 24 | class VideoListDockWidget(QtWidgets.QDockWidget): 25 | """dock for listing video files associated with the project.""" 26 | 27 | selectionChanged = QtCore.Signal(str) 28 | 29 | def __init__(self, *args, **kwargs): 30 | super().__init__(*args, **kwargs) 31 | self.setWindowTitle("Project Videos") 32 | self.file_list = _VideoListWidget(self) 33 | self.setWidget(self.file_list) 34 | self._project = None 35 | self.file_list.currentItemChanged.connect(self._selection_changed) 36 | 37 | def _selection_changed(self, current, _): 38 | """Emit signal when the selected video changes.""" 39 | if current: 40 | self.selectionChanged.emit(current.text()) 41 | 42 | def set_project(self, project): 43 | """Update the video list with the active project's videos and select first video in list.""" 44 | self._project = project 45 | self.file_list.clear() 46 | self.file_list.addItems(self._project.video_manager.videos) 47 | if self._project.video_manager.videos: 48 | self.file_list.setCurrentRow(0) 49 | -------------------------------------------------------------------------------- /src/jabs/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """JABS utilities""" 2 | 3 | from .utilities import get_bool_env_var, hash_file, hide_stderr 4 | 5 | # a hard coded random seed used for the final training done with all 6 | # training data before saving the classifier 7 | # the choice of random seed is arbitrary 8 | FINAL_TRAIN_SEED = 0xAB3BDB 9 | 10 | __all__ = [ 11 | "FINAL_TRAIN_SEED", 12 | "get_bool_env_var", 13 | "hash_file", 14 | "hide_stderr", 15 | ] 16 | -------------------------------------------------------------------------------- /src/jabs/utils/utilities.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import math 3 | import os 4 | import sys 5 | from contextlib import contextmanager 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | 10 | 11 | @contextmanager 12 | def hide_stderr() -> int: 13 | """Context manager to temporarily suppress output to standard error (stderr). 14 | 15 | Redirects all output sent to stderr to os.devnull while the context is active, 16 | restoring stderr to its original state upon exit. 
17 | 
18 |     Yields:
19 |         int: The file descriptor for stderr.
20 |     """
21 |     fd = sys.stderr.fileno()
22 | 
23 |     # copy fd before it is overwritten
24 |     with os.fdopen(os.dup(fd), "wb") as copied:
25 |         sys.stderr.flush()
26 | 
27 |         # open destination
28 |         with open(os.devnull, "wb") as fout:
29 |             os.dup2(fout.fileno(), fd)
30 |         try:
31 |             yield fd
32 |         finally:
33 |             # restore stderr to its previous value
34 |             sys.stderr.flush()
35 |             os.dup2(copied.fileno(), fd)
36 | 
37 | 
38 | def rolling_window(a, window, step_size=1):
39 |     """Creates a rolling window view of a 1D numpy array.
40 | 
41 |     Generates a view of the input array with overlapping windows of the specified size,
42 |     optionally with a custom step size between windows.
43 | 
44 |     Args:
45 |         a (np.ndarray): Input 1D array.
46 |         window (int): Size of each rolling window.
47 |         step_size (int, optional): Step size between windows. Defaults to 1.
48 | 
49 |     Returns:
50 |         np.ndarray: A 2D array where each row is a windowed view of the input.
51 | 
52 |     Raises:
53 |         ValueError: If the window size is larger than the input array.
54 |     """
55 |     shape = a.shape[:-1] + ((a.shape[-1] - window) // step_size + 1, window)
56 |     strides = (*a.strides, a.strides[-1] * step_size)
57 |     return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
58 | 
59 | 
60 | def smooth(vec, smoothing_window):
61 |     """Smooths a 1D numpy array using a moving average with edge padding.
62 | 
63 |     Pads the input vector at both ends with its edge values, then applies a moving average
64 |     of the specified window size. The window size must be odd.
65 | 
66 |     Args:
67 |         vec (np.ndarray): Input 1D array to smooth.
68 |         smoothing_window (int): Size of the moving average window (must be odd).
69 | 
70 |     Returns:
71 |         np.ndarray: Smoothed array as float values.
72 | 
73 |     Raises:
74 |         AssertionError: If `smoothing_window` is not odd.
75 |     """
76 |     if smoothing_window <= 1 or len(vec) == 0:
77 |         return vec.astype(np.float64)
78 |     else:
79 |         assert smoothing_window % 2 == 1, "expected smoothing_window to be odd"
80 |         half_conv_len = smoothing_window // 2
81 |         smooth_tgt = np.concatenate(
82 |             [
83 |                 np.full(half_conv_len, vec[0], dtype=vec.dtype),
84 |                 vec,
85 |                 np.full(half_conv_len, vec[-1], dtype=vec.dtype),
86 |             ]
87 |         )
88 | 
89 |         smoothing_val = 1 / smoothing_window
90 |         conv_arr = np.full(smoothing_window, smoothing_val)
91 | 
92 |         return np.convolve(smooth_tgt, conv_arr, mode="valid")
93 | 
94 | 
95 | def n_choose_r(n, r):
96 |     """compute number of unique selections (disregarding order) of r items from a set of n items
97 | 
98 |     Args:
99 |         n: number of elements to select from
100 |         r: number of elements to select
101 | 
102 |     Returns:
103 |         total number of combinations disregarding order
104 |     """
105 |     return math.factorial(n) // (math.factorial(r) * math.factorial(n - r))
106 | 
107 | 
108 | def hash_file(file: Path):
109 |     """return a blake2b hash (hex digest) of a file"""
110 |     chunk_size = 8192
111 |     with file.open("rb") as f:
112 |         h = hashlib.blake2b(digest_size=20)
113 |         c = f.read(chunk_size)
114 |         while c:
115 |             h.update(c)
116 |             c = f.read(chunk_size)
117 |     return h.hexdigest()
118 | 
119 | 
120 | def get_bool_env_var(var_name, default_value=False) -> bool:
121 |     """Gets a boolean value from an environment variable.
122 | 
123 |     Args:
124 |         var_name: The name of the environment variable.
125 |         default_value: The default value to return if the variable is
126 |             not set or invalid.
127 | 
128 |     Returns:
129 |         A boolean value.
130 |     """
131 |     value = os.getenv(var_name)
132 |     if value is None:
133 |         return default_value
134 | 
135 |     return value.lower() in ("true", "1", "yes", "on", "y", "t")
136 | 
--------------------------------------------------------------------------------
/src/jabs/version/__init__.py:
--------------------------------------------------------------------------------
1 | """jabs version"""
2 | 
3 | import importlib.metadata
4 | from pathlib import Path
5 | 
6 | import toml
7 | 
8 | 
9 | def version_str() -> str:
10 |     """Return version string from package metadata or pyproject.toml.
11 | 
12 |     If jabs-behavior-classifier is an installed package, gets the version from the package metadata. If not installed,
13 |     attempts to read the project's pyproject.toml file to get the version. Returns 'dev' if it's not able to determine
14 |     the version using either of these methods.
15 |     """
16 |     try:
17 |         return importlib.metadata.version("jabs-behavior-classifier")
18 |     except importlib.metadata.PackageNotFoundError:
19 |         pyproject_file = Path(__file__).parent.parent.parent.parent / "pyproject.toml"
20 |         try:
21 |             data = toml.load(pyproject_file)
22 |             return data["tool"]["poetry"]["version"]
23 |         except (FileNotFoundError, KeyError, toml.TomlDecodeError):
24 |             return "dev"
25 | 
--------------------------------------------------------------------------------
/src/jabs/video_reader/__init__.py:
--------------------------------------------------------------------------------
1 | """video reader
2 | 
3 | This package handles reading frames from a video file as well as applying various
4 | annotations (such as pose overlay, animal track (trajectory), segmentation, etc.)
5 | """
6 | 
7 | from .frame_annotation import (
8 |     draw_track,
9 |     label_identity,
10 |     overlay_landmarks,
11 |     overlay_pose,
12 |     overlay_segmentation,
13 | )
14 | from .video_reader import VideoReader
15 | 
16 | __all__ = [
17 |     "VideoReader",
18 |     "draw_track",
19 |     "label_identity",
20 |     "overlay_landmarks",
21 |     "overlay_pose",
22 |     "overlay_segmentation",
23 | ]
24 | 
--------------------------------------------------------------------------------
/src/jabs/video_reader/utilities.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | 
3 | 
4 | def get_frame_count(video_path: str):
5 |     """Get the number of frames in a video file.
6 | 
7 |     Args:
8 |         video_path: string containing path to video file
9 | 
10 |     Returns:
11 |         Integer number of frames in video.
12 | 
13 |     Raises:
14 |         OSError: if unable to open specified video
15 |     """
16 |     # open video file
17 |     stream = cv2.VideoCapture(video_path)
18 |     if not stream.isOpened():
19 |         raise OSError(f"unable to open {video_path}")
20 | 
21 |     return int(stream.get(cv2.CAP_PROP_FRAME_COUNT))
22 | 
23 | 
24 | def get_fps(video_path: str):
25 |     """get the frames per second from a video file"""
26 |     # open video file
27 |     stream = cv2.VideoCapture(video_path)
28 |     if not stream.isOpened():
29 |         raise OSError(f"unable to open {video_path}")
30 | 
31 |     return round(stream.get(cv2.CAP_PROP_FPS))
32 | 
--------------------------------------------------------------------------------
/src/jabs/video_reader/video_reader.py:
--------------------------------------------------------------------------------
1 | import time
2 | import typing
3 | from pathlib import Path
4 | 
5 | import cv2
6 | 
7 | 
8 | class VideoReader:
9 |     """VideoReader.
10 | 
11 |     Uses OpenCV to open a video file and read frames.
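load_next_frame returns a dict with 'data' (the frame image), 'index', and 'duration' keys; end of file is signaled by a result whose 'index' is -1.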
12 | """ 13 | 14 | _EOF: typing.ClassVar[dict] = {"data": None, "index": -1} 15 | 16 | def __init__(self, path: Path): 17 | """Initialize a VideoReader object. 18 | 19 | Args: 20 | path: path to video file 21 | """ 22 | # open video file 23 | self.stream = cv2.VideoCapture(str(path)) 24 | if not self.stream.isOpened(): 25 | raise OSError(f"unable to open {path}") 26 | 27 | self._frame_index = 0 28 | self._num_frames = int(self.stream.get(cv2.CAP_PROP_FRAME_COUNT)) 29 | 30 | # get frame rate 31 | self._fps = round(self.stream.get(cv2.CAP_PROP_FPS)) 32 | 33 | # calculate duration in seconds of each frame based on frame rate 34 | self._duration = 1.0 / self._fps 35 | 36 | # get frame dimensions 37 | self._width = int(self.stream.get(cv2.CAP_PROP_FRAME_WIDTH)) 38 | self._height = int(self.stream.get(cv2.CAP_PROP_FRAME_HEIGHT)) 39 | 40 | self._filename = path.name 41 | 42 | @property 43 | def num_frames(self): 44 | """get total number of frames in the video""" 45 | return self._num_frames 46 | 47 | @property 48 | def fps(self): 49 | """get frames per second from video""" 50 | return self._fps 51 | 52 | @property 53 | def dimensions(self): 54 | """return width, height of video frames""" 55 | return self._width, self._height 56 | 57 | @property 58 | def filename(self): 59 | """return the name of the video file""" 60 | return self._filename 61 | 62 | def get_frame_time(self, frame_number): 63 | """return a formatted string of the time of a given frame""" 64 | return time.strftime("%H:%M:%S", time.gmtime(frame_number * self._duration)) 65 | 66 | def seek(self, index): 67 | """Seek to a specific frame. 68 | 69 | This will clear the buffer and insert the frame at the new position. 70 | 71 | Note: 72 | some video formats might not be able to seek to an exact frame 73 | position so this could be slow in those cases. Our avi files have 74 | reasonable seek times. 75 | """ 76 | if self.stream.set(cv2.CAP_PROP_POS_FRAMES, index): 77 | self._frame_index = index 78 | 79 | def load_next_frame(self) -> dict: 80 | """grab the next frame from the file""" 81 | (grabbed, frame) = self.stream.read() 82 | if grabbed: 83 | data = { 84 | "data": frame, 85 | "index": self._frame_index, 86 | "duration": self._duration, 87 | } 88 | self._frame_index += 1 89 | else: 90 | data = self._EOF 91 | return data 92 | 93 | @staticmethod 94 | def _resize_image(image, width=None, height=None, interpolation=None): 95 | """resize an image, allow passing only desired width or height to maintain current aspect ratio 96 | 97 | Args: 98 | image: image to resize 99 | width: new width, if None compute to maintain aspect ratio 100 | height: new height, if None compute to maintain aspect ratio 101 | interpolation: type of interpolation to use for resize. 
If 102 | None, we will default to cv2.INTER_AREA for shrinking cv2.INTER_CUBIC when 103 | expanding 104 | 105 | Returns: 106 | resized image 107 | """ 108 | # current size 109 | (h, w) = image.shape[:2] 110 | 111 | # if both the width and height are None, then return the 112 | # original image 113 | if width is None and height is None: 114 | return image 115 | 116 | if width is None: 117 | # calculate the ratio of the height and construct the 118 | # dimensions 119 | r = height / float(h) 120 | dim = (int(w * r), height) 121 | 122 | elif height is None: 123 | # calculate the ratio of the width and construct the 124 | # dimensions 125 | r = width / float(w) 126 | dim = (width, int(h * r)) 127 | 128 | else: 129 | dim = (width, height) 130 | 131 | if interpolation is None: 132 | inter = cv2.INTER_AREA if dim[0] * dim[1] < w * h else cv2.INTER_CUBIC 133 | else: 134 | inter = interpolation 135 | 136 | # resize the image 137 | resized = cv2.resize(image, dim, interpolation=inter) 138 | 139 | # return the resized image 140 | return resized 141 | 142 | @classmethod 143 | def get_nframes_from_file(cls, path: Path): 144 | """get the number of frames by inspecting the video file""" 145 | # open video file 146 | stream = cv2.VideoCapture(str(path)) 147 | if not stream.isOpened(): 148 | raise OSError(f"unable to open {path}") 149 | 150 | return int(stream.get(cv2.CAP_PROP_FRAME_COUNT)) 151 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumarLabJax/JABS-behavior-classifier/4749d378a18e51eed91eea04d0f002472b8abcc2/tests/__init__.py -------------------------------------------------------------------------------- /tests/data/identity_with_no_data_pose_est_v3.h5.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumarLabJax/JABS-behavior-classifier/4749d378a18e51eed91eea04d0f002472b8abcc2/tests/data/identity_with_no_data_pose_est_v3.h5.gz -------------------------------------------------------------------------------- /tests/data/readme.txt: -------------------------------------------------------------------------------- 1 | This directory contains data files needed to run some of the unit tests. 2 | 3 | Note: 4 | Unless Git Large File Support is enabled, keep the data file sizes fairly small to avoid repository bloat. 
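5 | 
6 | The pose files are stored gzip-compressed. A minimal sketch for decompressing one manually
7 | (this mirrors what the test fixtures, e.g. tests/feature_modules/base.py, do during setUpClass):
8 | 
9 |     import gzip, shutil
10 |     with gzip.open("tests/data/sample_pose_est_v5.h5.gz", "rb") as f_in:
11 |         with open("sample_pose_est_v5.h5", "wb") as f_out:
12 |             shutil.copyfileobj(f_in, f_out)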
-------------------------------------------------------------------------------- /tests/data/sample_pose_est_v2.h5.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumarLabJax/JABS-behavior-classifier/4749d378a18e51eed91eea04d0f002472b8abcc2/tests/data/sample_pose_est_v2.h5.gz -------------------------------------------------------------------------------- /tests/data/sample_pose_est_v3.h5.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumarLabJax/JABS-behavior-classifier/4749d378a18e51eed91eea04d0f002472b8abcc2/tests/data/sample_pose_est_v3.h5.gz -------------------------------------------------------------------------------- /tests/data/sample_pose_est_v4.h5.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumarLabJax/JABS-behavior-classifier/4749d378a18e51eed91eea04d0f002472b8abcc2/tests/data/sample_pose_est_v4.h5.gz -------------------------------------------------------------------------------- /tests/data/sample_pose_est_v5.h5.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumarLabJax/JABS-behavior-classifier/4749d378a18e51eed91eea04d0f002472b8abcc2/tests/data/sample_pose_est_v5.h5.gz -------------------------------------------------------------------------------- /tests/data/sample_pose_est_v6.h5.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumarLabJax/JABS-behavior-classifier/4749d378a18e51eed91eea04d0f002472b8abcc2/tests/data/sample_pose_est_v6.h5.gz -------------------------------------------------------------------------------- /tests/feature_modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KumarLabJax/JABS-behavior-classifier/4749d378a18e51eed91eea04d0f002472b8abcc2/tests/feature_modules/__init__.py -------------------------------------------------------------------------------- /tests/feature_modules/base.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import shutil 3 | import tempfile 4 | import unittest 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | 9 | import jabs.pose_estimation 10 | 11 | 12 | class TestFeatureBase(unittest.TestCase): 13 | _tmpdir = None 14 | _test_file = Path(__file__).parent.parent / "data" / "sample_pose_est_v5.h5.gz" 15 | 16 | @classmethod 17 | def setUpClass(cls) -> None: 18 | cls._tmpdir = tempfile.TemporaryDirectory() 19 | cls._tmpdir_path = Path(cls._tmpdir.name) 20 | 21 | pose_path = cls._tmpdir_path / "sample_pose_est_v5.h5" 22 | 23 | # decompress pose file into tempdir 24 | with gzip.open(cls._test_file, "rb") as f_in: 25 | with open(pose_path, "wb") as f_out: 26 | shutil.copyfileobj(f_in, f_out) 27 | 28 | cls._pose_est_v5 = jabs.pose_estimation.open_pose_file( 29 | cls._tmpdir_path / "sample_pose_est_v5.h5" 30 | ) 31 | 32 | # V5 pose file in the data directory does not currently have "lixit" 33 | # as one of its static objects, so we'll manually add it 34 | cls._pose_est_v5.static_objects["lixit"] = np.asarray( 35 | [[62, 166]], dtype=np.uint16 36 | ) 37 | 38 | # V5 pose file also doesn't have the food hopper static object 39 | cls._pose_est_v5.static_objects["food_hopper"] = np.asarray( 40 | [[7, 291], [7, 528], [44, 296], [44, 518]] 41 | 
) 42 | 43 | @classmethod 44 | def tearDownClass(cls) -> None: 45 | cls._tmpdir.cleanup() 46 | -------------------------------------------------------------------------------- /tests/feature_modules/test_corner_features.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import src.jabs.feature_extraction.landmark_features.corner as corner_module 4 | from tests.feature_modules.base import TestFeatureBase 5 | 6 | 7 | class TestCornerFeatures(TestFeatureBase): 8 | def test_compute_corner_distances(self): 9 | pixel_scale = self._pose_est_v5.cm_per_pixel 10 | dist = corner_module.CornerDistanceInfo(self._pose_est_v5, pixel_scale) 11 | dist_to_corner = corner_module.DistanceToCorner( 12 | self._pose_est_v5, pixel_scale, dist 13 | ) 14 | bearing_to_corner = corner_module.BearingToCorner( 15 | self._pose_est_v5, pixel_scale, dist 16 | ) 17 | 18 | # check dimensions of per frame feature values 19 | for i in range(self._pose_est_v5.num_identities): 20 | dist_per_frame = dist_to_corner.per_frame(i) 21 | 22 | self.assertEqual( 23 | dist_per_frame["distance to corner"].shape, 24 | (self._pose_est_v5.num_frames,), 25 | ) 26 | 27 | bearing_per_frame = bearing_to_corner.per_frame(i) 28 | self.assertEqual( 29 | bearing_per_frame["bearing to corner"].shape, 30 | (self._pose_est_v5.num_frames,), 31 | ) 32 | 33 | # check dimensions of window feature values 34 | dist_window_values = dist_to_corner.window(i, 5, dist_per_frame) 35 | for op in dist_window_values: 36 | self.assertEqual( 37 | dist_window_values[op]["distance to corner"].shape, 38 | (self._pose_est_v5.num_frames,), 39 | ) 40 | 41 | bearing_window_values = bearing_to_corner.window(i, 5, bearing_per_frame) 42 | for op in bearing_window_values: 43 | self.assertEqual( 44 | bearing_window_values[op]["bearing to corner"].shape, 45 | (self._pose_est_v5.num_frames,), 46 | ) 47 | 48 | # check range of bearings, should be in the range [180, -180) 49 | for i in range(self._pose_est_v5.num_identities): 50 | values = bearing_to_corner.per_frame(i)["bearing to corner"] 51 | non_nan_indices = ~np.isnan(values) 52 | self.assertTrue( 53 | ( 54 | (values[non_nan_indices] <= 180) & (values[non_nan_indices] > -180) 55 | ).all() 56 | ) 57 | 58 | # check distances are >= 0 59 | for i in range(self._pose_est_v5.num_identities): 60 | values = dist_to_corner.per_frame(i)["distance to corner"] 61 | non_nan_indices = ~np.isnan(values) 62 | self.assertTrue((values[non_nan_indices] >= 0).all()) 63 | -------------------------------------------------------------------------------- /tests/feature_modules/test_food_hopper.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | from src.jabs.pose_estimation import PoseEstimation 5 | from src.jabs.feature_extraction.landmark_features.food_hopper import FoodHopper 6 | from src.jabs.feature_extraction.landmark_features.food_hopper import _EXCLUDED_POINTS 7 | from tests.feature_modules.base import TestFeatureBase 8 | 9 | 10 | class TestFoodHopper(TestFeatureBase): 11 | @classmethod 12 | def setUpClass(cls) -> None: 13 | super().setUpClass() 14 | 15 | pixel_scale = cls._pose_est_v5.cm_per_pixel 16 | cls.food_hopper_feature = FoodHopper(cls._pose_est_v5, pixel_scale) 17 | 18 | def test_dimensions(self) -> None: 19 | # check dimensions of per frame feature values 20 | for i in range(self._pose_est_v5.num_identities): 21 | values = self.food_hopper_feature.per_frame(i) 22 | 23 | # TODO check dimensions 
of all key points, not just for NOSE
24 |             self.assertEqual(
25 |                 values["food hopper NOSE"].shape, (self._pose_est_v5.num_frames,)
26 |             )
27 | 
28 |             # check dimensions of window feature values
29 |             dist_window_values = self.food_hopper_feature.window(i, 5, values)
30 |             for op in dist_window_values:
31 |                 self.assertEqual(
32 |                     dist_window_values[op]["food hopper NOSE"].shape,
33 |                     (self._pose_est_v5.num_frames,),
34 |                 )
35 | 
36 |     def test_signed_dist(self) -> None:
37 |         values = self.food_hopper_feature.per_frame(0)
38 | 
39 |         # perform a couple manual computations of signed distance and check
40 |         hopper = self._pose_est_v5.static_objects["food_hopper"]
41 |         if self._pose_est_v5.cm_per_pixel is not None:
42 |             hopper = hopper * self._pose_est_v5.cm_per_pixel
43 |         # swap the point x,y values and change dtype to float32 for open cv
44 |         hopper_pts = hopper[:, [1, 0]].astype(np.float32)
45 | 
46 |         points, _ = self._pose_est_v5.get_identity_poses(
47 |             0, self._pose_est_v5.cm_per_pixel
48 |         )
49 | 
50 |         for key_point in PoseEstimation.KeypointIndex:
51 |             # skip over the key points we don't care about
52 |             if key_point in _EXCLUDED_POINTS:
53 |                 continue
54 | 
55 |             # swap our x,y to match the opencv coordinate space
56 |             pts = points[:, key_point.value, [1, 0]]
57 | 
58 |             # check values for this keypoint for a few different frames
59 |             for i in [5, 10, 50, 100, 200, 500, 1000]:
60 |                 signed_dist = cv2.pointPolygonTest(
61 |                     hopper_pts, (pts[i, 0], pts[i, 1]), True
62 |                 )
63 |                 if np.isnan(pts[i, 0]):
64 |                     signed_dist = np.nan
65 | 
66 |                 if not np.isnan(signed_dist):
67 |                     self.assertAlmostEqual(
68 |                         signed_dist, values[f"food hopper {key_point.name}"][i]
69 |                     )
70 |                 else:
71 |                     self.assertTrue(
72 |                         np.isnan(values[f"food hopper {key_point.name}"][i])
73 |                     )
74 | 
75 |     def test_frame_out_of_range(self) -> None:
76 |         with self.assertRaises(IndexError):
77 |             _ = self.food_hopper_feature.per_frame(0)["food hopper NOSE"][100000]
78 | 
79 |     def test_identity_out_of_range(self) -> None:
80 |         with self.assertRaises(IndexError):
81 |             _ = self.food_hopper_feature.per_frame(100)[0]
82 | 
--------------------------------------------------------------------------------
/tests/feature_modules/test_lixit_distance.py:
--------------------------------------------------------------------------------
1 | # TODO these tests need to be fixed, they were broken during a change in how features were stored/retrieved
2 | import unittest
3 | 
4 | import numpy as np
5 | 
6 | import src.jabs.feature_extraction.landmark_features.lixit as lixit
7 | from tests.feature_modules.base import TestFeatureBase
8 | 
9 | 
10 | class TestLixitFeatures(TestFeatureBase):
11 |     @classmethod
12 |     def setUpClass(cls) -> None:
13 |         super().setUpClass()
14 | 
15 |         pixel_scale = cls._pose_est_v5.cm_per_pixel
16 |         cls.distance_info = lixit.LixitDistanceInfo(cls._pose_est_v5, pixel_scale)
17 |         cls.lixit_distance = lixit.DistanceToLixit(
18 |             cls._pose_est_v5, pixel_scale, cls.distance_info
19 |         )
20 | 
21 |     @unittest.skip("")
22 |     def test_dimensions(self):
23 |         # check dimensions of per frame feature values
24 |         for i in range(self._pose_est_v5.num_identities):
25 |             distances = self.lixit_distance.per_frame(i)
26 |             self.assertEqual(distances.shape, (self._pose_est_v5.num_frames,))
27 | 
28 |             # check dimensions of window feature values
29 |             dist_window_values = self.lixit_distance.window(i, 5, distances)
30 |             for op in dist_window_values:
31 |                 self.assertEqual(
32 |                     dist_window_values[op].shape, (self._pose_est_v5.num_frames,)
33 |                 )
34 | 
35 |     @unittest.skip("")
36

    @unittest.skip("broken by a change in how features are stored/retrieved, see TODO")
    def test_distances_greater_equal_zero(self):
        for i in range(self._pose_est_v5.num_identities):
            distances = self.lixit_distance.per_frame(i)
            # distances are magnitudes, so every value should be >= 0
            self.assertTrue((distances >= 0).all())

    @unittest.skip("broken by a change in how features are stored/retrieved, see TODO")
    def test_computation(self):
        # spot check some distance values for identity 0
        expected = np.asarray(
            [
                10.44161892,
                10.51272678,
                10.60188961,
                10.67293644,
                10.7262249,
                10.70890999,
                10.70890999,
                10.7802515,
                10.7802515,
                10.90649891,
            ],
            dtype=np.float32,
        )
        actual = self.lixit_distance.per_frame(0)
        for i in range(expected.shape[0]):
            self.assertAlmostEqual(expected[i], actual[i])


--------------------------------------------------------------------------------
/tests/feature_tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KumarLabJax/JABS-behavior-classifier/4749d378a18e51eed91eea04d0f002472b8abcc2/tests/feature_tests/__init__.py


--------------------------------------------------------------------------------
/tests/feature_tests/seg_test_utils.py:
--------------------------------------------------------------------------------
"""
This file contains the base class SegDataBaseClass, which can be inherited by
tests that use segmentation data in the pose file (v6+).
"""

import gzip
import shutil
import tempfile
from pathlib import Path

import src.jabs.pose_estimation as pose_est
from src.jabs.feature_extraction.segmentation_features import SegmentationFeatureGroup


class SegDataBaseClass:
    """Common setup and teardown for segmentation tests."""

    pixel_scale = 1.0
    dataPath = Path(__file__).parent / "../data"
    dataFileName = "sample_pose_est_v6.h5.gz"
    _tmpdir = None

    @classmethod
    def setUpClass(cls) -> None:
        cls._tmpdir = tempfile.TemporaryDirectory()
        cls._tmpdir_path = Path(cls._tmpdir.name)

        # decompress the sample pose file into the temporary directory
        extracted = cls._tmpdir_path / cls.dataFileName.replace(".h5.gz", ".h5")
        with gzip.open(cls.dataPath / cls.dataFileName, "rb") as f_in:
            with open(extracted, "wb") as f_out:
                shutil.copyfileobj(f_in, f_out)

        cls._pose_est_v6 = pose_est.open_pose_file(extracted)

        cls._moment_cache = SegmentationFeatureGroup(cls._pose_est_v6, cls.pixel_scale)
        cls.feature_mods = cls._moment_cache._init_feature_mods(1)

    @classmethod
    def tearDownClass(cls) -> None:
        if cls._tmpdir:
            cls._tmpdir.cleanup()


def setUpModule():
    """Use if code should be executed once for all tests."""
    pass
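
# Minimal usage sketch (hypothetical test class, mirroring test_hu_moments.py):
# inherit SegDataBaseClass alongside unittest.TestCase so that setUpClass
# populates self.feature_mods from the sample v6 pose file:
#
#     class TestMyFeature(SegDataBaseClass, unittest.TestCase):
#         def test_per_frame(self):
#             feature = self.feature_mods["hu_moments"]
#             values = feature.per_frame(0)  # per-frame values for identity 0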

--------------------------------------------------------------------------------
/tests/feature_tests/test_hu_moments.py:
--------------------------------------------------------------------------------
import unittest

# project imports
from .seg_test_utils import SegDataBaseClass as SBC


class TestHuMoments(SBC, unittest.TestCase):
    """This test provides coverage for the HuMoments feature class."""

    def testHuMomentFeatureName(self) -> None:
        """Test the HuMoments feature names and per-frame values."""

        # test that data was read and set up correctly
        huMomentsFeature = self.feature_mods["hu_moments"]

        # seven Hu invariant moments; the second-to-last name should be hu6
        self.assertEqual(huMomentsFeature._feature_names[-2], "hu6")

        # per-frame values for identity 0: one entry per Hu moment
        huMoments_by_frame = huMomentsFeature.per_frame(0)

        self.assertEqual(len(huMoments_by_frame), 7)


--------------------------------------------------------------------------------
/tests/feature_tests/test_moments.py:
--------------------------------------------------------------------------------
import unittest
import gzip
import numpy as np
from pathlib import Path
import tempfile
import shutil

# project imports
from src.jabs.feature_extraction.segmentation_features import (
    SegmentationFeatureGroup,
    Moments,
)
from .seg_test_utils import SegDataBaseClass as SBC
import src.jabs.pose_estimation as pose_est


class TestImportSrc(unittest.TestCase):
    @unittest.skip("")
    def test(self):
        """
        # error: FileNotFoundError: Could not find module '