├── .bumpversion.cfg ├── .cookiecutter.json ├── .coveragerc ├── .editorconfig ├── .gitignore ├── .travis.yml ├── AUTHORS.rst ├── CONTRIBUTING.md ├── HISTORY.rst ├── LICENSE.txt ├── MANIFEST.in ├── README.md ├── allensdk ├── __init__.py └── eye_tracking │ ├── __init__.py │ ├── __main__.py │ ├── _schemas.py │ ├── eye_tracking.py │ ├── feature_extraction.py │ ├── fit_ellipse.py │ ├── frame_stream.py │ ├── plotting.py │ ├── ransac.py │ ├── ui │ ├── __init__.py │ ├── __main__.py │ └── qt.py │ └── utils.py ├── appveyor.yml ├── docs ├── Makefile ├── _static │ └── .gitkeep ├── allensdk.eye_tracking.rst ├── allensdk.eye_tracking.ui.rst ├── allensdk.rst ├── authors.rst ├── conf.py ├── history.rst ├── index.rst ├── installation.rst ├── make.bat ├── modules.rst └── usage.rst ├── requirements.txt ├── requirements_dev.txt ├── setup.cfg ├── setup.py ├── test ├── test_eye_tracking.py ├── test_feature_extraction.py ├── test_fit_ellipse.py ├── test_frame_stream.py ├── test_module.py ├── test_plotting.py ├── test_qt_ui.py ├── test_ransac.py └── test_utils.py └── test_requirements.txt /.bumpversion.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 1.2.1 3 | commit = True 4 | tag = True 5 | 6 | [bumpversion:file:allensdk/eye_tracking/__init__.py] 7 | search = __version__ = '{current_version}' 8 | replace = __version__ = '{new_version}' 9 | 10 | [bumpversion:file:setup.cfg] 11 | search = version = {current_version} 12 | replace = version = {new_version} 13 | 14 | [bumpversion:file:setup.py] 15 | search = version='{current_version}' 16 | replace = version='{new_version}' 17 | 18 | -------------------------------------------------------------------------------- /.cookiecutter.json: -------------------------------------------------------------------------------- 1 | { 2 | "email": "jedp@alleninstitute.org", 3 | "full_name": "Jed Perkins", 4 | "open_source_license": "Allen Institute Software License", 5 | "project_name": "AllenSDK Eye Tracking", 6 | "project_namespace": "allensdk", 7 | "project_short_description": "Allen Institute package for mouse eye tracking.", 8 | "project_slug": "eye_tracking", 9 | "repo_url": "https://github.com/AllenInstitute/allensdk.eye_tracking", 10 | "user_name": "jedp", 11 | "version": "0.1.0" 12 | } -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = True 3 | source = ./ 4 | 5 | [report] 6 | exclude_lines = 7 | if __name__ == .__main__.: 8 | ignore_errors = True 9 | omit = 10 | test/* 11 | setup.py -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | 13 | [*.bat] 14 | indent_style = tab 15 | end_of_line = crlf 16 | 17 | [LICENSE] 18 | insert_final_newline = false 19 | 20 | [Makefile] 21 | indent_style = tab 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | .pytest_cache/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | 56 | # Sphinx documentation 57 | docs/_build/ 58 | 59 | # PyBuilder 60 | target/ 61 | 62 | # pyenv python configuration file 63 | .python-version 64 | 65 | # IDE Specific 66 | .idea 67 | .vscode* -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | matrix: 2 | include: 3 | - os: linux 4 | sudo: required 5 | python: 2.7 6 | env: TEST_PYTHON_VERSION=2.7 7 | - os: linux 8 | sudo: required 9 | python: 3.6 10 | env: TEST_PYTHON_VERSION=3.6 11 | - os: osx 12 | language: generic 13 | env: TEST_PYTHON_VERSION=2.7 14 | - os: osx 15 | language: generic 16 | env: TEST_PYTHON_VERSION=3.6 17 | 18 | install: 19 | - if [[ "$TRAVIS_OS_NAME" == "linux" ]]; then 20 | sudo apt-get update; 21 | if [[ "$TEST_PYTHON_VERSION" == "2.7" ]]; then 22 | wget https://repo.continuum.io/miniconda/Miniconda2-latest-Linux-x86_64.sh -O miniconda.sh; 23 | else 24 | wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh; 25 | fi 26 | else 27 | brew update; 28 | if [[ "$TEST_PYTHON_VERSION" == "2.7" ]]; then 29 | wget https://repo.continuum.io/miniconda/Miniconda2-latest-MacOSX-x86_64.sh -O miniconda.sh; 30 | else 31 | wget https://repo.continuum.io/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -O miniconda.sh; 32 | fi 33 | fi 34 | - bash miniconda.sh -b -p $HOME/miniconda 35 | - export PATH="$HOME/miniconda/bin:$PATH" 36 | - hash -r 37 | - conda config --set always_yes yes --set changeps1 no 38 | - conda update -q conda 39 | - conda create -q -n test-environment python=$TEST_PYTHON_VERSION pip 40 | - source activate test-environment 41 | - conda install -c conda-forge xvfbwrapper 42 | - conda install -c conda-forge opencv=3.3.0 43 | - conda install -c conda-forge pyqt 44 | - pip install codecov 45 | - pip install -r test_requirements.txt 46 | - if [[ "$TRAVIS_OS_NAME" == "linux" ]]; then 47 | pip install pytest-xvfb; 48 | fi 49 | - pip install . 50 | 51 | script: 52 | - coverage run --source ./ -m pytest 53 | - codecov -------------------------------------------------------------------------------- /AUTHORS.rst: -------------------------------------------------------------------------------- 1 | ======= 2 | Credits 3 | ======= 4 | 5 | Development Lead 6 | ---------------- 7 | 8 | * Jed Perkins `@JFPerkins `_ 9 | 10 | Contributors 11 | ------------ 12 | 13 | Initial contributions to internal eye tracking system on AllenSDK: 14 | 15 | * David Feng `@dyf `_ 16 | * Michael Buice `@mabuice `_ 17 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Allen Institute Contribution Agreement 2 | 3 | This document describes the terms under which you may make “Contributions” — 4 | which may include without limitation, software additions, revisions, bug fixes, configuration changes, 5 | documentation, or any other materials — to any of the projects owned or managed by the Allen Institute. 6 | If you have questions about these terms, please contact us at terms@alleninstitute.org. 7 | 8 | You certify that: 9 | 10 | • Your Contributions are either: 11 | 12 | 1. Created in whole or in part by you and you have the right to submit them under the designated license 13 | (described below); or 14 | 2. Based upon previous work that, to the best of your knowledge, is covered under an appropriate 15 | open source license and you have the right under that license to submit that work with modifications, 16 | whether created in whole or in part by you, under the designated license; or 17 | 18 | 3. Provided directly to you by some other person who certified (1) or (2) and you have not modified them. 19 | 20 | • You are granting your Contributions to the Allen Institute under the terms of the [2-Clause BSD license](https://opensource.org/licenses/BSD-2-Clause) 21 | (the “designated license”). 22 | 23 | • You understand and agree that the Allen Institute projects and your Contributions are public and that 24 | a record of the Contributions (including all metadata and personal information you submit with them) is 25 | maintained indefinitely and may be redistributed consistent with the Allen Institute’s mission and the 26 | 2-Clause BSD license. 27 | -------------------------------------------------------------------------------- /HISTORY.rst: -------------------------------------------------------------------------------- 1 | ======= 2 | History 3 | ======= 4 | 5 | 1.2.1 (2018-08-23) 6 | ------------------ 7 | * Update defaults based on internal user overrides for Allen Institute videos. 8 | * Remove generic threshold_factor and threshold_pixels settings for starburst. 9 | * Fix argument parsing to not error out if nested values aren't provided. 10 | 11 | 1.2.0 (2018-02-24) 12 | ------------------ 13 | * Rework median filtering and correlation functions to use OpenCV instead of 14 | scipy for performance improvements. 15 | * Fix seed point finding method to use template matching to improve point finding 16 | with the default bounding boxes. 17 | * Add keyword arguments to filters for candidate points. 18 | * Expose additional input parameters: average_iris_intensity, clip_pupil_values, 19 | and max_eccentricity. 20 | * Add constraints to EllipseFitter, preventing any ellipse axis longer than the 21 | index ray length as well as limiting eccentricity to below max_eccentricity. 22 | * Use the keyword arguments for candidate point filters to filter rays where a 23 | the baseline intensity is out of pupil limits if clip_pupil_values is set. 24 | * Add plot of average pupil intensity to QC output to check behavior of adaptive 25 | pupil tracking. 26 | * Add plot of best fit error to QC output. 27 | * Add UI for testing configuration parameters and generating input jsons. 28 | 29 | 1.1.1 (2018-02-13) 30 | ------------------ 31 | * Expose median kernel smoothing to the command line. 32 | * Add seed point and candidate pupil points to annotation output. 33 | 34 | 1.1.0 (2018-02-11) 35 | ------------------ 36 | * Add frame iteration to allow processing subsets of movies. Also 37 | add bounding box image to QC output. 38 | 39 | 1.0.0 (2018-02-07) 40 | ------------------ 41 | * Rename from aibs.eye_tracking to allensdk.eye_tracking. 42 | 43 | 0.2.3 (2018-02-06) 44 | ------------------ 45 | * Add options to set cr_threshold_factor, cr_threshold_pixels, pupil_threshold_factor, 46 | pupil_threshold_pixels in the starburst parameters. They will override the 47 | default threshold_factor and threshold_pixels if set. 48 | * Add option to turn off adaptive pupil shade tracking. 49 | Exposes fourcc string as parameter for annotation in case default codec is not 50 | supported or desired. 51 | 52 | 0.2.2 (2018-01-20) 53 | ------------------ 54 | * Fix matplotlib backend warning. 55 | * Show help if required argument is missing or input command is incorrect. 56 | 57 | 0.2.1 (2017-12-13) 58 | ------------------ 59 | * Fix bug preventing module running when number of frames was not specified. 60 | 61 | 0.2.0 (2017-12-11) 62 | ------------------ 63 | * Initial release of independent eye tracker. 64 | 65 | 0.1.0 (2017-10-19) 66 | ------------------ 67 | * Initial port over of eye tracking code from AllenSDK internal. 68 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Allen Institute Software License - This software license is the 2-clause BSD license 2 | plus a third clause that prohibits redistribution for commercial purposes without further permission. 3 | 4 | Copyright (c) 2018. Allen Institute. All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the 7 | following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the 10 | following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the 13 | following disclaimer in the documentation and/or other materials provided with the distribution. 14 | 15 | 3. Redistributions for commercial purposes are not permitted without the Allen Institute's written permission. 16 | For purposes of this license, commercial purposes is the incorporation of the Allen Institute's software into 17 | anything for which you will charge fees or other compensation. Contact terms@alleninstitute.org for commercial 18 | licensing opportunities. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, 21 | INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 25 | WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE 26 | USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt 2 | include test_requirements.txt 3 | include AUTHORS.rst 4 | include HISTORY.rst 5 | include LICENSE 6 | include README.md 7 | 8 | recursive-include tests * 9 | recursive-exclude * __pycache__ 10 | recursive-exclude * *.py[co] 11 | 12 | recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/AllenInstitute/allensdk.eye_tracking.svg?branch=master)](https://travis-ci.org/AllenInstitute/allensdk.eye_tracking) 2 | [![Build status](https://ci.appveyor.com/api/projects/status/spkm2kb09u70a3n5/branch/master?svg=true)](https://ci.appveyor.com/project/JFPerkins/allensdk-eye-tracking/branch/master) 3 | [![codecov](https://codecov.io/gh/AllenInstitute/allensdk.eye_tracking/branch/master/graph/badge.svg)](https://codecov.io/gh/AllenInstitute/allensdk.eye_tracking) 4 | 5 | AllenSDK Eye Tracking 6 | ===================== 7 | 8 | This is the python package the Allen Institute uses for estimating 9 | pupil position and shape from eye videos. The position of a LED 10 | reflection on the cornea is also tracked and is a required feature of 11 | the input streams. The input videos are black and white. 12 | 13 | Source: https://github.com/AllenInstitute/allensdk.eye_tracking 14 | 15 | Installation 16 | ------------ 17 | The video IO is done using OpenCV's video functionality. Unfortunately, 18 | OpenCV on pip seems to not be built with the necessary backend, as the 19 | methods fail silently. As a result, we have not included OpenCV in the 20 | requirements and it is necessary to get it seperately, built with the 21 | video capture and writing functional. Additionally, on some platforms 22 | scikit-image does not build easily from source and the developers don't 23 | have binary distributions for all platforms yet. The simplest way to 24 | install these difficult dependencies is to use conda: 25 | 26 | conda install scikit-image 27 | conda install -c conda-forge opencv=3.3.0 28 | conda install -c conda-forge pyqt 29 | 30 | The version of opencv is pinned because the latest (3.4.1 as of this 31 | writing) seems to have a bug with the VideoCapture code which causes 32 | errors reading videos on linux. Latest does seem to work on Windows. 33 | The rest of the dependencies are all in the requirements, so to install 34 | just clone or download the repository and then from inside the top 35 | level directory either run: 36 | 37 | pip install . 38 | 39 | or 40 | 41 | python setup.py install 42 | 43 | Usage 44 | ----- 45 | After installing the package, and entry point is created so it can be run 46 | from the command line. To minimally run with the default settings: 47 | 48 | allensdk.eye_tracking --input_source 49 | 50 | To see all options that can be set at the command line: 51 | 52 | allensdk.eye_tracking --help 53 | 54 | There are a lot of options that can be set, so often it can be more 55 | convenient to store them in a JSON-formatted file which can be used like: 56 | 57 | allensdk.eye_tracking --input_json 58 | 59 | The input json can be combined with other command line argument, which will 60 | take precedence over anything in the json. There is a UI tool for adjusting 61 | and saving input parameters that can be used by running: 62 | 63 | allensdk.eye_tracking_ui 64 | 65 | Description of algorithm 66 | ------------------------ 67 | The general way that the algorithm works is to (for every frame): 68 | 69 | 1. Use a simple bright circle template to estimate the seed point for 70 | searching for a corneal reflection of the LED. 71 | 2. Draw rays from the seed point and find light-to-dark threshold 72 | crossings to generate estimated points for an ellipse fit. 73 | 3. Use ransac to find the best fit ellipse to the points. 74 | 4. Optionally fill in the estimated corneal reflection with the last 75 | shade of the pupil. This is necessary if the corneal reflection 76 | occludes the pupil at all. 77 | 5. Repeat steps 1-3, but with a dark circle template and dark-to-light 78 | threshold crossings to find the pupil ellipse parameters. 79 | -------------------------------------------------------------------------------- /allensdk/__init__.py: -------------------------------------------------------------------------------- 1 | __path__ = __import__('pkgutil').extend_path(__path__, __name__) 2 | -------------------------------------------------------------------------------- /allensdk/eye_tracking/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Top-level package for AllenSDK Eye Tracking.""" 4 | 5 | __author__ = """Jed Perkins""" 6 | __email__ = 'jedp@alleninstitute.org' 7 | __version__ = '1.2.1' 8 | -------------------------------------------------------------------------------- /allensdk/eye_tracking/__main__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import marshmallow 5 | from argschema import ArgSchemaParser 6 | from argschema.utils import schema_argparser 7 | import warnings 8 | import matplotlib 9 | with warnings.catch_warnings(): 10 | warnings.simplefilter("ignore") 11 | matplotlib.use("Agg") 12 | from matplotlib import pyplot as plt # noqa: E402 13 | from ._schemas import (InputParameters, OutputParameters, DEFAULT_ANNOTATION) # noqa: E402 E501 14 | from .eye_tracking import EyeTracker, PointGenerator # noqa: E402 15 | from .frame_stream import CvInputStream, CvOutputStream # noqa: E402 16 | from .plotting import (plot_summary, plot_cumulative, 17 | annotate_with_box, plot_timeseries) # noqa: E402 18 | from allensdk.eye_tracking import __version__ # noqa E402 19 | 20 | 21 | def setup_annotation(im_shape, annotate_movie, output_file, fourcc="H264"): 22 | if annotate_movie: 23 | ostream = CvOutputStream(output_file, im_shape[::-1], fourcc=fourcc) 24 | ostream.open(output_file) 25 | else: 26 | ostream = None 27 | return ostream 28 | 29 | 30 | def write_output(output_dir, cr_parameters, pupil_parameters, mean_frame): 31 | output = { 32 | "cr_parameter_file": os.path.join(output_dir, "cr_params.npy"), 33 | "pupil_parameter_file": os.path.join(output_dir, "pupil_params.npy"), 34 | "mean_frame_file": os.path.join(output_dir, "mean_frame.png"), 35 | "module_version": __version__ 36 | } 37 | plt.imsave(output["mean_frame_file"], mean_frame, cmap="gray") 38 | np.save(output["cr_parameter_file"], cr_parameters) 39 | np.save(output["pupil_parameter_file"], pupil_parameters) 40 | 41 | return output 42 | 43 | 44 | def write_QC_output(annotator, cr_parameters, pupil_parameters, 45 | cr_errors, pupil_errors, mean_frame, pupil_intensity=None, 46 | **kwargs): 47 | output_dir = kwargs.get("qc", {}).get("output_dir", kwargs["output_dir"]) 48 | annotator.annotate_with_cumulative_cr( 49 | mean_frame, os.path.join(output_dir, "cr_all.png")) 50 | annotator.annotate_with_cumulative_pupil( 51 | mean_frame, os.path.join(output_dir, "pupil_all.png")) 52 | plot_cumulative(annotator.densities["pupil"], annotator.densities["cr"], 53 | output_dir=output_dir) 54 | plot_summary(pupil_parameters, cr_parameters, output_dir=output_dir) 55 | pupil_bbox = kwargs.get("pupil_bounding_box", []) 56 | cr_bbox = kwargs.get("cr_bounding_box", []) 57 | if len(pupil_bbox) == 4: 58 | mean_frame = annotate_with_box(mean_frame, pupil_bbox, 59 | (0, 0, 255)) 60 | if len(cr_bbox) == 4: 61 | mean_frame = annotate_with_box(mean_frame, cr_bbox, 62 | (255, 0, 0)) 63 | plt.imsave(os.path.join(output_dir, "mean_frame_bbox.png"), mean_frame) 64 | plot_timeseries(pupil_errors, None, title="pupil ellipse fit errors", 65 | filename=os.path.join(output_dir, "pupil_ellipse_err.png")) 66 | plot_timeseries(cr_errors, None, title="cr ellipse fit errors", 67 | filename=os.path.join(output_dir, "cr_ellipse_err.png")) 68 | if pupil_intensity: 69 | plot_timeseries( 70 | pupil_intensity, None, title="estimated pupil intensity", 71 | filename=os.path.join(output_dir, "pupil_intensity.png")) 72 | 73 | 74 | def main(): 75 | """Main entry point for running AllenSDK Eye Tracking.""" 76 | try: 77 | mod = ArgSchemaParser(schema_type=InputParameters, 78 | output_schema_type=OutputParameters) 79 | 80 | istream = CvInputStream(mod.args["input_source"]) 81 | 82 | ostream = setup_annotation(istream.frame_shape, 83 | **mod.args.get("annotation", 84 | DEFAULT_ANNOTATION)) 85 | 86 | qc_params = mod.args.get("qc", {}) 87 | generate_plots = qc_params.get( 88 | "generate_plots", EyeTracker.DEFAULT_GENERATE_QC_OUTPUT) 89 | 90 | tracker = EyeTracker(istream, 91 | ostream, 92 | mod.args.get("starburst", {}), 93 | mod.args.get("ransac", {}), 94 | mod.args["pupil_bounding_box"], 95 | mod.args["cr_bounding_box"], 96 | generate_plots, 97 | **mod.args.get("eye_params", {})) 98 | cr_params, pupil_params, cr_err, pupil_err = tracker.process_stream( 99 | start=mod.args.get("start_frame", 0), 100 | stop=mod.args.get("stop_frame", None), 101 | step=mod.args.get("frame_step", 1) 102 | ) 103 | 104 | output = write_output(mod.args["output_dir"], cr_params, 105 | pupil_params, tracker.mean_frame) 106 | 107 | pupil_intensity = None 108 | if tracker.adaptive_pupil: 109 | pupil_intensity = tracker.pupil_colors 110 | if generate_plots: 111 | write_QC_output(tracker.annotator, cr_params, pupil_params, 112 | cr_err, pupil_err, tracker.mean_frame, 113 | pupil_intensity=pupil_intensity, **mod.args) 114 | 115 | output["input_parameters"] = mod.args 116 | if "output_json" in mod.args: 117 | mod.output(output, indent=1) 118 | else: 119 | print(json.dumps(mod.get_output_json(output), indent=1)) 120 | except marshmallow.ValidationError as e: 121 | print(e) 122 | argparser = schema_argparser(InputParameters()) 123 | argparser.print_usage() 124 | 125 | 126 | if __name__ == "__main__": 127 | main() 128 | -------------------------------------------------------------------------------- /allensdk/eye_tracking/_schemas.py: -------------------------------------------------------------------------------- 1 | from argschema import ArgSchema 2 | from argschema.schemas import DefaultSchema 3 | from argschema.fields import (Nested, OutputDir, InputFile, Bool, Float, Int, 4 | OutputFile, NumpyArray, Str) 5 | from .eye_tracking import PointGenerator, EyeTracker 6 | from .fit_ellipse import EllipseFitter 7 | 8 | DEFAULT_ANNOTATION = {"annotate_movie": False, 9 | "output_file": "./annotated.avi"} 10 | 11 | 12 | class RansacParameters(DefaultSchema): 13 | minimum_points_for_fit = Int( 14 | default=EllipseFitter.DEFAULT_MINIMUM_POINTS_FOR_FIT, 15 | description="Number of points required to fit data") 16 | number_of_close_points = Int( 17 | default=EllipseFitter.DEFAULT_NUMBER_OF_CLOSE_POINTS, 18 | description=("Number of candidate outliers reselected as inliers " 19 | "required to consider a good fit")) 20 | threshold = Float( 21 | default=EllipseFitter.DEFAULT_THRESHOLD, 22 | description=("Error threshold below which data should be considered " 23 | "an inlier")) 24 | iterations = Int( 25 | default=EllipseFitter.DEFAULT_ITERATIONS, 26 | description="Number of iterations to run") 27 | 28 | 29 | class AnnotationParameters(DefaultSchema): 30 | annotate_movie = Bool( 31 | default=DEFAULT_ANNOTATION["annotate_movie"], 32 | description="Flag for whether or not to annotate") 33 | output_file = OutputFile(default=DEFAULT_ANNOTATION["output_file"]) 34 | fourcc = Str(description=("FOURCC string for video encoding. On Windows " 35 | "H264 is not available by default, so it will " 36 | "need to be installed or a different codec " 37 | "used.")) 38 | 39 | 40 | class StarburstParameters(DefaultSchema): 41 | index_length = Int( 42 | default=PointGenerator.DEFAULT_INDEX_LENGTH, 43 | description="Initial default length for rays") 44 | n_rays = Int( 45 | default=PointGenerator.DEFAULT_N_RAYS, 46 | description="Number of rays to draw") 47 | cr_threshold_factor = Float( 48 | default=PointGenerator.DEFAULT_THRESHOLD_FACTOR, 49 | description=("Threshold factor for corneal reflection ellipse edges, " 50 | "will supercede `threshold_factor` for corneal " 51 | "reflection if specified")) 52 | pupil_threshold_factor = Float( 53 | default=PointGenerator.DEFAULT_THRESHOLD_FACTOR, 54 | description=("Threshold factor for pupil ellipse edges, will " 55 | "supercede `threshold_factor` for pupil if specified")) 56 | cr_threshold_pixels = Int( 57 | default=PointGenerator.DEFAULT_CR_THRESHOLD_PIXELS, 58 | description=("Number of pixels from start of ray to use for adaptive " 59 | "threshold of the corneal reflection. Also serves as a " 60 | "minimum cutoff for point detection")) 61 | pupil_threshold_pixels = Int( 62 | default=PointGenerator.DEFAULT_PUPIL_THRESHOLD_PIXELS, 63 | description=("Number of pixels from start of ray to use for adaptive " 64 | "threshold of the pupil. Also serves as a minimum " 65 | "cutoff for point detection")) 66 | 67 | 68 | class EyeParameters(DefaultSchema): 69 | cr_recolor_scale_factor = Float( 70 | default=EyeTracker.DEFAULT_CR_RECOLOR_SCALE_FACTOR, 71 | description="Size multiplier for corneal reflection recolor mask") 72 | min_pupil_value = Int( 73 | default=EyeTracker.DEFAULT_MIN_PUPIL_VALUE, 74 | description="Minimum value the average pupil shade can be") 75 | max_pupil_value = Int( 76 | default=EyeTracker.DEFAULT_MAX_PUPIL_VALUE, 77 | description="Maximum value the average pupil shade can be") 78 | recolor_cr = Bool( 79 | default=EyeTracker.DEFAULT_RECOLOR_CR, 80 | description="Flag for recoloring corneal reflection") 81 | adaptive_pupil = Bool( 82 | default=EyeTracker.DEFAULT_ADAPTIVE_PUPIL, 83 | description="Flag for whether or not to adaptively update pupil color") 84 | pupil_mask_radius = Int( 85 | default=EyeTracker.DEFAULT_PUPIL_MASK_RADIUS, 86 | description="Radius of pupil mask used to find seed point") 87 | cr_mask_radius = Int( 88 | default=EyeTracker.DEFAULT_CR_MASK_RADIUS, 89 | description="Radius of cr mask used to find seed point") 90 | smoothing_kernel_size = Int( 91 | default=EyeTracker.DEFAULT_SMOOTHING_KERNEL_SIZE, 92 | description=("Kernel size for median filter smoothing kernel (must be " 93 | "odd)")) 94 | clip_pupil_values = Bool( 95 | default=EyeTracker.DEFAULT_CLIP_PUPIL_VALUES, 96 | description=("Flag of whether or not to restrict pupil values for " 97 | "starburst to fall within the range of (min_pupil_value, " 98 | "max_pupil_value)")) 99 | average_iris_intensity = Int( 100 | default=EyeTracker.DEFAULT_AVERAGE_IRIS_INTENSITY, 101 | description="Average expected intensity of the iris") 102 | max_eccentricity = Float( 103 | default=EyeTracker.DEFAULT_MAX_ECCENTRICITY, 104 | description="Maximum eccentricity allowed for pupil.") 105 | 106 | 107 | class QCParameters(DefaultSchema): 108 | generate_plots = Bool( 109 | default=EyeTracker.DEFAULT_GENERATE_QC_OUTPUT, 110 | description="Flag for whether or not to output QC plots") 111 | output_dir = OutputDir( 112 | default="./", 113 | description="Folder to store QC outputs") 114 | 115 | 116 | class InputParameters(ArgSchema): 117 | output_dir = OutputDir( 118 | default="./", 119 | description="Directory in which to store data output files") 120 | input_source = InputFile( 121 | description="Path to input movie", 122 | required=True) 123 | pupil_bounding_box = NumpyArray(dtype="int", default=[]) 124 | cr_bounding_box = NumpyArray(dtype="int", default=[]) 125 | start_frame = Int( 126 | description="Frame of movie to start processing at") 127 | stop_frame = Int( 128 | description="Frame of movie to end processing at") 129 | frame_step = Int( 130 | description=("Interval of frames to process. Used for skipping frames," 131 | "if 1 it will process every frame between start and stop") 132 | ) 133 | ransac = Nested(RansacParameters) 134 | annotation = Nested(AnnotationParameters) 135 | starburst = Nested(StarburstParameters) 136 | eye_params = Nested(EyeParameters) 137 | qc = Nested(QCParameters) 138 | 139 | 140 | class OutputSchema(DefaultSchema): 141 | input_parameters = Nested( 142 | InputParameters, 143 | description="Input parameters the module was run with", 144 | required=True) 145 | 146 | 147 | class OutputParameters(OutputSchema): 148 | cr_parameter_file = OutputFile(required=True) 149 | pupil_parameter_file = OutputFile(required=True) 150 | mean_frame_file = OutputFile(required=True) 151 | module_version = Str(required=True) 152 | -------------------------------------------------------------------------------- /allensdk/eye_tracking/eye_tracking.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import cv2 3 | import numpy as np 4 | from .fit_ellipse import EllipseFitter, ellipse_pass_filter 5 | from .utils import generate_ray_indices, get_ray_values 6 | from .feature_extraction import (get_circle_template, 7 | max_correlation_positions) 8 | from .plotting import Annotator, ellipse_points 9 | 10 | 11 | class PointGenerator(object): 12 | """Class to find candidate points for ellipse fitting. 13 | 14 | Candidates points are found by drawing rays from a seed point and 15 | checking for the first threshold crossing of each ray. 16 | 17 | Parameters 18 | ---------- 19 | index_length : int 20 | Initial default length for ray indices. 21 | n_rays : int 22 | The number of rays to check. 23 | cr_threshold_factor : float 24 | Multiplicative factor for thresholding corneal reflection. 25 | pupil_threshold_factor : float 26 | Multiplicative factor for thresholding pupil. 27 | cr_threshold_pixels : int 28 | Number of pixels (from beginning of ray) to use to determine 29 | threshold of corneal reflection. 30 | pupil_threshold_pixels : int 31 | Number of pixels (from beginning of ray) to use to determine 32 | threshold of pupil. 33 | """ 34 | DEFAULT_INDEX_LENGTH = 150 35 | DEFAULT_N_RAYS = 150 36 | DEFAULT_THRESHOLD_FACTOR = 1.2 37 | DEFAULT_CR_THRESHOLD_PIXELS = 10 38 | DEFAULT_PUPIL_THRESHOLD_PIXELS = 22 39 | 40 | def __init__(self, index_length=DEFAULT_INDEX_LENGTH, 41 | n_rays=DEFAULT_N_RAYS, 42 | cr_threshold_factor=DEFAULT_THRESHOLD_FACTOR, 43 | pupil_threshold_factor=DEFAULT_THRESHOLD_FACTOR, 44 | cr_threshold_pixels=DEFAULT_CR_THRESHOLD_PIXELS, 45 | pupil_threshold_pixels=DEFAULT_PUPIL_THRESHOLD_PIXELS): 46 | self.update_params(index_length=index_length, n_rays=n_rays, 47 | cr_threshold_factor=cr_threshold_factor, 48 | pupil_threshold_factor=pupil_threshold_factor, 49 | cr_threshold_pixels=cr_threshold_pixels, 50 | pupil_threshold_pixels=pupil_threshold_pixels) 51 | self.above_threshold = {"cr": False, 52 | "pupil": True} 53 | self._intensity_estimate = 0 54 | 55 | def update_params(self, index_length=DEFAULT_INDEX_LENGTH, 56 | n_rays=DEFAULT_N_RAYS, 57 | cr_threshold_factor=DEFAULT_THRESHOLD_FACTOR, 58 | pupil_threshold_factor=DEFAULT_THRESHOLD_FACTOR, 59 | cr_threshold_pixels=DEFAULT_CR_THRESHOLD_PIXELS, 60 | pupil_threshold_pixels=DEFAULT_PUPIL_THRESHOLD_PIXELS): 61 | """Update starburst point generation parameters. 62 | 63 | Parameters 64 | ---------- 65 | index_length : int 66 | Initial default length for ray indices. 67 | n_rays : int 68 | The number of rays to check. 69 | cr_threshold_factor : float 70 | Multiplicative factor for thresholding corneal reflection. 71 | pupil_threshold_factor : float 72 | Multiplicative factor for thresholding pupil. 73 | cr_threshold_pixels : int 74 | Number of pixels (from beginning of ray) to use to determine 75 | threshold of corneal reflection. 76 | pupil_threshold_pixels : int 77 | Number of pixels (from beginning of ray) to use to determine 78 | threshold of pupil. 79 | """ 80 | self.index_length = index_length 81 | self.xs, self.ys = generate_ray_indices(index_length, n_rays) 82 | self.threshold_pixels = {"cr": cr_threshold_pixels, 83 | "pupil": pupil_threshold_pixels} 84 | self.threshold_factor = {"cr": cr_threshold_factor, 85 | "pupil": pupil_threshold_factor} 86 | 87 | def get_candidate_points(self, image, seed_point, point_type, 88 | filter_function=None, filter_args=(), 89 | filter_kwargs=None): 90 | """Get candidate points for ellipse fitting. 91 | 92 | Parameters 93 | ---------- 94 | image : numpy.ndarray 95 | Image to check for threshold crossings. 96 | seed_point : tuple 97 | (y, x) center point for ray burst. 98 | point_type : str 99 | Either 'cr' or 'pupil'. Determines if threshold crossing is 100 | high-to-low or low-to-high and which `threshold_factor` and 101 | `threshold_pixels` value to use. 102 | 103 | Returns 104 | ------- 105 | candidate_points : list 106 | List of (y, x) candidate points. 107 | """ 108 | xs = self.xs + seed_point[1] 109 | ys = self.ys + seed_point[0] 110 | ray_values = get_ray_values(xs, ys, image) 111 | filtered_out = 0 112 | threshold_not_crossed = 0 113 | candidate_points = [] 114 | if filter_kwargs is None: 115 | filter_kwargs = {} 116 | for i, values in enumerate(ray_values): 117 | try: 118 | point = self.threshold_crossing( 119 | xs[i], ys[i], values, point_type) 120 | if filter_function is not None: 121 | filter_kwargs["pupil_intensity_estimate"] = \ 122 | self._intensity_estimate 123 | if filter_function(point, *filter_args, **filter_kwargs): 124 | candidate_points.append(point) 125 | else: 126 | filtered_out += 1 127 | else: 128 | candidate_points.append(point) 129 | except ValueError: 130 | threshold_not_crossed += 1 131 | if threshold_not_crossed or filtered_out: 132 | logging.debug(("%s candidate points returned, %s filtered out, %s " 133 | "not generated because threshold not crossed"), 134 | len(candidate_points), filtered_out, 135 | threshold_not_crossed) 136 | return candidate_points 137 | 138 | def threshold_crossing(self, xs, ys, values, point_type): 139 | """Check a ray for where it crosses a threshold. 140 | 141 | The threshold is calculated using `get_threshold`. 142 | 143 | Parameters 144 | ---------- 145 | xs : numpy.ndarray 146 | X indices of ray. 147 | ys : numpy.ndarray 148 | Y indices of ray. 149 | values : numpy.ndarray 150 | Image values along ray. 151 | point_type : str 152 | Either 'cr' or 'pupil'. Determines if threshold crossing is 153 | high-to-low or low-to-high and which `threshold_factor` and 154 | `threshold_pixels` value to use. 155 | 156 | Returns 157 | ------- 158 | y_index : int 159 | Y index of threshold crossing. 160 | x_index : int 161 | X index of threshold crossing. 162 | 163 | Raises 164 | ------ 165 | ValueError 166 | If no threshold crossing found. 167 | """ 168 | try: 169 | above_threshold = self.above_threshold[point_type] 170 | threshold_pixels = self.threshold_pixels[point_type] 171 | threshold_factor = self.threshold_factor[point_type] 172 | except KeyError: 173 | raise ValueError(("'{}' is not a supported point type, must be " 174 | "'cr' or 'pupil'").format(point_type)) 175 | threshold = self.get_threshold(values, threshold_pixels, 176 | threshold_factor) 177 | if above_threshold: 178 | comparison = values[threshold_pixels:] > threshold 179 | else: 180 | comparison = values[threshold_pixels:] < threshold 181 | sub_index = np.argmax(comparison) 182 | if comparison[sub_index]: 183 | index = threshold_pixels + sub_index 184 | return ys[index], xs[index] 185 | else: 186 | raise ValueError("No value in array crosses: {}".format(threshold)) 187 | 188 | def get_threshold(self, ray_values, threshold_pixels, threshold_factor): 189 | """Calculate the threshold from the ray values. 190 | 191 | The threshold is determined from `threshold_factor` times the 192 | mean of the first `threshold_pixels` values. 193 | 194 | Parameters 195 | ---------- 196 | ray_values : numpy.ndarray 197 | Values of the ray. 198 | threshold_factor : float 199 | Multiplicative factor for thresholding. 200 | threshold_pixels : int 201 | Number of pixels (from beginning of ray) to use to determine 202 | threshold. 203 | 204 | Returns 205 | ------- 206 | threshold : float 207 | Threshold to set for candidate point. 208 | """ 209 | sub_ray = ray_values[threshold_pixels] 210 | self._intensity_estimate = np.mean(sub_ray) 211 | threshold = threshold_factor*self._intensity_estimate 212 | 213 | return threshold 214 | 215 | 216 | class EyeTracker(object): 217 | """Mouse Eye-Tracker. 218 | 219 | Parameters 220 | ---------- 221 | input_stream : generator 222 | Generator that yields numpy.ndarray frames to analyze. 223 | output_stream : stream 224 | Stream that accepts numpuy.ndarrays in the write method. None if 225 | not outputting annotations. 226 | starburst_params : dict 227 | Dictionary of keyword arguments for `PointGenerator`. 228 | ransac_params : dict 229 | Dictionary of keyword arguments for `EllipseFitter`. 230 | pupil_bounding_box : numpy.ndarray 231 | [xmin xmax ymin ymax] bounding box for pupil seed point search. 232 | cr_bounding_box : numpy.ndarray 233 | [xmin xmax ymin ymax] bounding box for cr seed point search. 234 | generate_QC_output : bool 235 | Flag to compute extra QC data on frames. 236 | **kwargs 237 | pupil_min_value : int 238 | pupil_max_value : int 239 | cr_mask_radius : int 240 | pupil_mask_radius : int 241 | cr_recolor_scale_factor : float 242 | recolor_cr : bool 243 | adaptive_pupil: bool 244 | smoothing_kernel_size : int 245 | clip_pupil_values : bool 246 | average_iris_intensity : int 247 | """ 248 | DEFAULT_MIN_PUPIL_VALUE = 0 249 | DEFAULT_MAX_PUPIL_VALUE = 40 250 | DEFAULT_CR_RECOLOR_SCALE_FACTOR = 1.7 251 | DEFAULT_RECOLOR_CR = True 252 | DEFAULT_ADAPTIVE_PUPIL = False 253 | DEFAULT_CR_MASK_RADIUS = 10 254 | DEFAULT_PUPIL_MASK_RADIUS = 35 255 | DEFAULT_GENERATE_QC_OUTPUT = False 256 | DEFAULT_SMOOTHING_KERNEL_SIZE = 7 257 | DEFAULT_CLIP_PUPIL_VALUES = True 258 | DEFAULT_AVERAGE_IRIS_INTENSITY = 40 259 | DEFAULT_MAX_ECCENTRICITY = 0.25 260 | 261 | def __init__(self, input_stream, output_stream=None, 262 | starburst_params=None, ransac_params=None, 263 | pupil_bounding_box=None, cr_bounding_box=None, 264 | generate_QC_output=DEFAULT_GENERATE_QC_OUTPUT, **kwargs): 265 | self._mean_frame = None 266 | self._input_stream = None 267 | self.input_stream = input_stream 268 | self.point_generator = None 269 | self.ellipse_fitter = None 270 | self.min_pupil_value = self.DEFAULT_MIN_PUPIL_VALUE 271 | self.max_pupil_value = self.DEFAULT_MAX_PUPIL_VALUE 272 | self.cr_recolor_scale_factor = self.DEFAULT_CR_RECOLOR_SCALE_FACTOR 273 | self.recolor_cr = self.DEFAULT_RECOLOR_CR 274 | self.cr_mask_radius = self.DEFAULT_CR_MASK_RADIUS 275 | self.pupil_mask_radius = self.DEFAULT_PUPIL_MASK_RADIUS 276 | self.adaptive_pupil = self.DEFAULT_ADAPTIVE_PUPIL 277 | self.smoothing_kernel_size = self.DEFAULT_SMOOTHING_KERNEL_SIZE 278 | self.clip_pupil_values = self.DEFAULT_CLIP_PUPIL_VALUES 279 | self.average_iris_intensity = self.DEFAULT_AVERAGE_IRIS_INTENSITY 280 | self.max_eccentricity = self.DEFAULT_MAX_ECCENTRICITY 281 | self.update_fit_parameters(starburst_params=starburst_params, 282 | ransac_params=ransac_params, 283 | pupil_bounding_box=pupil_bounding_box, 284 | cr_bounding_box=cr_bounding_box, 285 | **kwargs) 286 | self.annotator = Annotator(output_stream) 287 | self.pupil_parameters = [] 288 | self.cr_parameters = [] 289 | self.pupil_colors = [] 290 | self.generate_QC_output = generate_QC_output 291 | self.current_seed = None 292 | self.current_pupil_candidates = None 293 | self.current_image = None 294 | self.current_image_mean = 0 295 | self.blurred_image = None 296 | self.cr_filled_image = None 297 | self.pupil_max_image = None 298 | self.annotated_image = None 299 | self.frame_index = 0 300 | 301 | def update_fit_parameters(self, starburst_params=None, ransac_params=None, 302 | pupil_bounding_box=None, cr_bounding_box=None, 303 | **kwargs): 304 | """Update EyeTracker fitting parameters. 305 | 306 | Parameters 307 | ---------- 308 | starburst_params : dict 309 | Dictionary of keyword arguments for `PointGenerator`. 310 | ransac_params : dict 311 | Dictionary of keyword arguments for `EllipseFitter`. 312 | pupil_bounding_box : numpy.ndarray 313 | [xmin xmax ymin ymax] bounding box for pupil seed point search. 314 | cr_bounding_box : numpy.ndarray 315 | [xmin xmax ymin ymax] bounding box for cr seed point search. 316 | generate_QC_output : bool 317 | Flag to compute extra QC data on frames. 318 | **kwargs 319 | pupil_min_value : int 320 | pupil_max_value : int 321 | cr_mask_radius : int 322 | pupil_mask_radius : int 323 | cr_recolor_scale_factor : float 324 | recolor_cr : bool 325 | adaptive_pupil: bool 326 | smoothing_kernel_size : int 327 | clip_pupil_values : bool 328 | average_iris_intensity : int 329 | """ 330 | if self.point_generator is None: 331 | if starburst_params is None: 332 | self.point_generator = PointGenerator() 333 | else: 334 | self.point_generator = PointGenerator(**starburst_params) 335 | elif starburst_params is not None: 336 | self.point_generator.update_params(**starburst_params) 337 | if self.ellipse_fitter is None: 338 | if ransac_params is None: 339 | self.ellipse_fitter = EllipseFitter() 340 | else: 341 | self.ellipse_fitter = EllipseFitter(**ransac_params) 342 | elif ransac_params is not None: 343 | self.ellipse_fitter.update_params(**ransac_params) 344 | if pupil_bounding_box is None or len(pupil_bounding_box) != 4: 345 | pupil_bounding_box = default_bounding_box(self.im_shape) 346 | if cr_bounding_box is None or len(cr_bounding_box) != 4: 347 | cr_bounding_box = default_bounding_box(self.im_shape) 348 | self.pupil_bounding_box = pupil_bounding_box 349 | self.cr_bounding_box = cr_bounding_box 350 | self._init_kwargs(**kwargs) 351 | self.current_seed = None 352 | self.current_pupil_candidates = None 353 | self.current_image = None 354 | self.current_image_mean = 0 355 | self.blurred_image = None 356 | self.cr_filled_image = None 357 | self.annotated_image = None 358 | 359 | def _init_kwargs(self, **kwargs): 360 | self.min_pupil_value = kwargs.get("min_pupil_value", 361 | self.min_pupil_value) 362 | self.max_pupil_value = kwargs.get("max_pupil_value", 363 | self.max_pupil_value) 364 | self.last_pupil_color = self.min_pupil_value 365 | self.cr_recolor_scale_factor = kwargs.get( 366 | "cr_recolor_scale_factor", self.cr_recolor_scale_factor) 367 | self.recolor_cr = kwargs.get("recolor_cr", self.recolor_cr) 368 | self.cr_mask_radius = kwargs.get("cr_mask_radius", self.cr_mask_radius) 369 | self.cr_mask = get_circle_template(self.cr_mask_radius, fill=1, 370 | surround=-1) 371 | self.pupil_mask_radius = kwargs.get("pupil_mask_radius", 372 | self.pupil_mask_radius) 373 | self.adaptive_pupil = kwargs.get( 374 | "adaptive_pupil", self.adaptive_pupil) 375 | self.smoothing_kernel_size = kwargs.get( 376 | "smoothing_kernel_size", self.smoothing_kernel_size) 377 | self.clip_pupil_values = kwargs.get( 378 | "clip_pupil_values", self.clip_pupil_values) 379 | if self.clip_pupil_values: 380 | self.pupil_limits = (self.min_pupil_value, 381 | self.max_pupil_value) 382 | else: 383 | self.pupil_limits = None 384 | self.average_iris_intensity = kwargs.get( 385 | "average_iris_intensity", self.average_iris_intensity) 386 | self.max_eccentricity = kwargs.get( 387 | "max_eccentricity", self.max_eccentricity) 388 | 389 | @property 390 | def im_shape(self): 391 | """Image shape.""" 392 | if self.input_stream is None: 393 | return None 394 | return self.input_stream.frame_shape 395 | 396 | @property 397 | def input_stream(self): 398 | """Input frame source.""" 399 | return self._input_stream 400 | 401 | @input_stream.setter 402 | def input_stream(self, stream): 403 | self._mean_frame = None 404 | if self._input_stream is not None: 405 | self._input_stream.close() 406 | if stream is not None and stream.frame_shape != self.im_shape: 407 | self.cr_bounding_box = default_bounding_box(stream.frame_shape) 408 | self.pupil_bounding_box = default_bounding_box(stream.frame_shape) 409 | self._input_stream = stream 410 | 411 | @property 412 | def mean_frame(self): 413 | """Average frame calculated from the input source.""" 414 | if self._mean_frame is None: 415 | mean_frame = np.zeros(self.im_shape, dtype=np.float64) 416 | frame_count = 0 417 | for frame in self.input_stream: 418 | mean_frame += frame 419 | frame_count += 1 420 | self._mean_frame = (mean_frame / frame_count).astype(np.uint8) 421 | return self._mean_frame 422 | 423 | def find_corneal_reflection(self): 424 | """Estimate the position of the corneal reflection. 425 | 426 | Returns 427 | ------- 428 | ellipse_parameters : tuple 429 | (x, y, r, a, b) ellipse parameters. 430 | """ 431 | seed_point = max_correlation_positions( 432 | self.blurred_image, self.cr_mask, self.cr_bounding_box) 433 | candidate_points = self.point_generator.get_candidate_points( 434 | self.blurred_image, seed_point, "cr") 435 | return self.ellipse_fitter.fit( 436 | candidate_points, max_radius=self.point_generator.index_length) 437 | 438 | def setup_pupil_finder(self, cr_parameters): 439 | """Initialize image and ransac filter for pupil fitting. 440 | 441 | If recoloring the corneal_reflection, color it in and provide a 442 | filter to exclude points that fall on the colored-in ellipse 443 | from fitting. 444 | 445 | Parameters 446 | ---------- 447 | cr_parameters : tuple 448 | (x, y, r, a, b) ellipse parameters for corneal reflection. 449 | 450 | Returns 451 | ------- 452 | image : numpy.ndarray 453 | Image for pupil fitting. Has corneal reflection filled in if 454 | `recolor_cr` is set. 455 | filter_function : callable 456 | Function to indicate if points fall on the recolored ellipse 457 | or None if not recoloring. 458 | filter_parameters : tuple 459 | Ellipse parameters for recolor ellipse shape, which are 460 | `cr_parameters` with the axes scaled by 461 | `cr_recolor_scale_factor`. 462 | """ 463 | if self.recolor_cr: 464 | self.recolor_corneal_reflection(cr_parameters) 465 | base_image = self.cr_filled_image 466 | filter_function = ellipse_pass_filter 467 | x, y, r, a, b = cr_parameters 468 | filter_params = (x, y, r, self.cr_recolor_scale_factor*a, 469 | self.cr_recolor_scale_factor*b) 470 | else: 471 | base_image = self.blurred_image 472 | filter_function = None 473 | filter_params = None 474 | 475 | return base_image, filter_function, filter_params 476 | 477 | def find_pupil(self, cr_parameters): 478 | """Estimate position of the pupil. 479 | 480 | Parameters 481 | ---------- 482 | cr_parameters : tuple 483 | (x, y, r, a, b) ellipse parameters of corneal reflection, 484 | used to prepare image if `recolor_cr` is set. 485 | 486 | Returns 487 | ------- 488 | ellipse_parameters : tuple 489 | (x, y, r, a, b) ellipse parameters. 490 | """ 491 | base_image, filter_function, filter_params = self.setup_pupil_finder( 492 | cr_parameters) 493 | pupil_mask = get_circle_template( 494 | self.pupil_mask_radius, 495 | int(self.last_pupil_color), 496 | int(self.average_iris_intensity)) 497 | 498 | # template matching uses top-left corner for the best match, so shift 499 | # rejection coordinates accordingly 500 | if self.recolor_cr: 501 | reject = (self._recolored_r - int(pupil_mask.shape[0]/2.0), 502 | self._recolored_c - int(pupil_mask.shape[1]/2.0)) 503 | else: 504 | reject = None 505 | seed_point = max_correlation_positions( 506 | base_image, pupil_mask, 507 | self.pupil_bounding_box, reject_coords=reject) 508 | 509 | filter_kwargs = {} 510 | if self.clip_pupil_values: 511 | filter_kwargs = {"pupil_limits": self.pupil_limits} 512 | 513 | candidate_points = self.point_generator.get_candidate_points( 514 | base_image, seed_point, "pupil", filter_function=filter_function, 515 | filter_args=(filter_params, 2), filter_kwargs=filter_kwargs) 516 | self.current_seed = seed_point 517 | self.current_pupil_candidates = candidate_points 518 | 519 | return self.ellipse_fitter.fit( 520 | candidate_points, max_radius=self.point_generator.index_length, 521 | max_eccentricity=self.max_eccentricity) 522 | 523 | def recolor_corneal_reflection(self, cr_parameters): 524 | """Reshade the corneal reflection with the last pupil color. 525 | 526 | Parameters 527 | ---------- 528 | cr_parameters : tuple 529 | (x, y, r, a, b) ellipse parameters for corneal reflection. 530 | """ 531 | x, y, r, a, b = cr_parameters 532 | a = self.cr_recolor_scale_factor*a + 1 533 | b = self.cr_recolor_scale_factor*b + 1 534 | r, c = ellipse_points((x, y, r, a, b), self.blurred_image.shape) 535 | self.cr_filled_image = self.blurred_image.copy() 536 | self.cr_filled_image[r, c] = self.last_pupil_color 537 | self._recolored_r = r 538 | self._recolored_c = c 539 | 540 | def update_last_pupil_color(self, pupil_parameters): 541 | """Update last pupil color with mean of fit. 542 | 543 | Parameters 544 | ---------- 545 | pupil_parameters : tuple 546 | (x, y, r, a, b) ellipse parameters for pupil. 547 | """ 548 | if np.any(np.isnan(pupil_parameters)): 549 | return 550 | if self.recolor_cr: 551 | image = self.cr_filled_image 552 | else: 553 | image = self.blurred_image 554 | r, c = ellipse_points(pupil_parameters, image.shape) 555 | value = int(np.mean(image[r, c])) 556 | value = max(self.min_pupil_value, value) 557 | value = min(self.max_pupil_value, value) 558 | self.last_pupil_color = value 559 | 560 | def process_image(self, image): 561 | """Process an image to find pupil and corneal reflection. 562 | 563 | Parameters 564 | ---------- 565 | image : numpy.ndarray 566 | Image to process. 567 | 568 | Returns 569 | ------- 570 | cr_parameters : tuple 571 | (x, y, r, a, b) corneal reflection parameters. 572 | pupil_parameters : tuple 573 | (x, y, r, a, b) pupil parameters. 574 | cr_error : float 575 | Ellipse fit error for best fit. 576 | pupil_error : float 577 | Ellipse fit error for best fit. 578 | """ 579 | self.current_image = image 580 | self.current_image_mean = self.current_image.mean() 581 | self.blurred_image = cv2.medianBlur(image, self.smoothing_kernel_size) 582 | try: 583 | cr_parameters, cr_error = self.find_corneal_reflection() 584 | except ValueError: 585 | logging.debug("Insufficient candidate points found for fitting " 586 | "corneal reflection at frame %s", self.frame_index) 587 | cr_parameters = (np.nan, np.nan, np.nan, np.nan, np.nan) 588 | cr_error = np.nan 589 | 590 | try: 591 | pupil_parameters, pupil_error = self.find_pupil(cr_parameters) 592 | if self.adaptive_pupil: 593 | self.update_last_pupil_color(pupil_parameters) 594 | except ValueError: 595 | logging.debug("Insufficient candidate points found for fitting " 596 | "pupil at frame %s", self.frame_index) 597 | pupil_parameters = (np.nan, np.nan, np.nan, np.nan, np.nan) 598 | pupil_error = np.nan 599 | 600 | return cr_parameters, pupil_parameters, cr_error, pupil_error 601 | 602 | def process_stream(self, start=0, stop=None, step=1, 603 | update_mean_frame=True): 604 | """Get cr and pupil parameters from frames of `input_stream`. 605 | 606 | By default this will process every frame in the input stream. 607 | 608 | Parameters 609 | ---------- 610 | start : int 611 | Index of first frame to process. Defaults to 0. 612 | stop : int 613 | Stop index for processing. Defaults to None, which runs 614 | runs until the end of the input stream. 615 | step : int 616 | Number of frames to advance at each iteration. Used to skip 617 | frames while processing. Set to 1 to process every frame, 2 618 | to process every other frame, etc. Defaults to 1. 619 | update_mean_frame : bool 620 | Whether or not to update the mean frame while processing 621 | the frames. 622 | 623 | Returns 624 | ------- 625 | cr_parameters : numpy.ndarray 626 | [n_frames,5] array of corneal reflection parameters. 627 | pupil_parameters : numpy.ndarray 628 | [n_frames,5] array of pupil parameters. 629 | cr_errors : numpy.ndarray 630 | [n_frames,] array of fit errors for corneal reflection 631 | ellipses. 632 | pupil_errors : numpy.ndarray 633 | [n_frames,] array of fit errors for pupil ellipses. 634 | """ 635 | self.pupil_parameters = [] 636 | self.cr_parameters = [] 637 | self.pupil_errors = [] 638 | self.cr_errors = [] 639 | self.pupil_colors = [] 640 | i = 0 641 | 642 | if update_mean_frame: 643 | mean_frame = np.zeros(self.im_shape, dtype=np.float64) 644 | 645 | for i, frame in enumerate(self.input_stream[start:stop:step]): 646 | if update_mean_frame: 647 | mean_frame += frame 648 | self.frame_index = start + step*i 649 | cr_parameters, pupil_parameters, cr_error, pupil_error = \ 650 | self.process_image(frame) 651 | self.cr_parameters.append(cr_parameters) 652 | self.pupil_parameters.append(pupil_parameters) 653 | self.cr_errors.append(cr_error) 654 | self.pupil_errors.append(pupil_error) 655 | self.pupil_colors.append(self.last_pupil_color) 656 | if self.annotator.output_stream is not None: 657 | self.annotated_image = self.annotator.annotate_frame( 658 | frame, pupil_parameters, cr_parameters, self.current_seed, 659 | self.current_pupil_candidates) 660 | if self.generate_QC_output: 661 | self.annotator.compute_density(frame, pupil_parameters, 662 | cr_parameters) 663 | self.annotator.clear_rc() 664 | 665 | self.annotator.close() 666 | 667 | if update_mean_frame: 668 | self._mean_frame = (mean_frame / (i+1)).astype(np.uint8) 669 | 670 | return (np.array(self.cr_parameters), np.array(self.pupil_parameters), 671 | np.array(self.cr_errors), np.array(self.pupil_errors)) 672 | 673 | 674 | def default_bounding_box(image_shape): 675 | """Calculate a default bounding box as 10% in from borders of image. 676 | 677 | Parameters 678 | ---------- 679 | image_shape : tuple 680 | (height, width) of image. 681 | 682 | Returns 683 | ------- 684 | bounding_box : numpy.ndarray 685 | [xmin, xmax, ymin, ymax] bounding box. 686 | """ 687 | if image_shape is None: 688 | return np.array([1, -1, 1, -1], dtype='int') 689 | 690 | h, w = image_shape 691 | x_crop = int(0.1*w) 692 | y_crop = int(0.1*h) 693 | 694 | return np.array([x_crop, w-x_crop, y_crop, h-y_crop], dtype='int') 695 | -------------------------------------------------------------------------------- /allensdk/eye_tracking/feature_extraction.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | _CIRCLE_TEMPLATES = {} 5 | 6 | 7 | def get_circle_template(radius, fill=1, surround=0): 8 | """Get circular template for estimating center point. 9 | 10 | Returns a cached template if it has already been computed. 11 | 12 | Parameters 13 | ---------- 14 | radius : int 15 | Radius in pixels of the circle to draw. 16 | 17 | Returns 18 | ------- 19 | template : numpy.ndarray 20 | Circle template. 21 | """ 22 | global _CIRCLE_TEMPLATES 23 | mask = _CIRCLE_TEMPLATES.get((radius, int(fill), int(surround)), None) 24 | if mask is None: 25 | Y, X = np.meshgrid(np.arange(-radius-3, radius+4), 26 | np.arange(-radius-3, radius+4)) 27 | mask = np.ones([2*radius + 7, 2*radius + 7], dtype=np.float)*surround 28 | circle = X**2 + Y**2 < radius**2 29 | mask[circle] = fill 30 | _CIRCLE_TEMPLATES[(int(radius), int(fill), int(surround))] = mask 31 | return mask 32 | 33 | 34 | def max_correlation_positions(image, template, bounding_box=None, 35 | reject_coords=None): 36 | """Correlate image with template and return the max location. 37 | 38 | Correlation is done with mode set to `mode` and method 'fft'. It 39 | is only performed over the image region within `bounding_box` if it 40 | is provided. The resulting coordinates are provided in the context 41 | of the original image. 42 | 43 | Parameters 44 | ---------- 45 | image : numpy.ndarray 46 | Image over which to convolve the kernel. 47 | template : numpy.ndarray 48 | Kernel to convolve with the image. 49 | bounding_box : numpy.ndarray 50 | [xmin, xmax, ymin, ymax] bounding box on the image. 51 | reject_coords : tuple 52 | (r, c) coordinates to disallow as best fit. 53 | 54 | Returns 55 | ------- 56 | max_position : tuple 57 | (y, x) mean location maximum of the convolution of the kernel 58 | with the image. 59 | """ 60 | if bounding_box is None: 61 | cropped_image = image 62 | xmin = 0 63 | ymin = 0 64 | else: 65 | xmin, xmax, ymin, ymax = bounding_box 66 | cropped_image = image[ymin:ymax, xmin:xmax] 67 | 68 | corr = cv2.matchTemplate(cropped_image.astype(np.float32), 69 | template.astype(np.float32), 70 | cv2.TM_CCORR_NORMED) 71 | 72 | if reject_coords: 73 | r = reject_coords[0] - ymin - template.shape[0] 74 | c = reject_coords[1] - xmin - template.shape[1] 75 | idx = (r >= 0) & (c >= 0) 76 | corr[r[idx], c[idx]] = -np.inf 77 | 78 | _, _, _, max_loc = cv2.minMaxLoc(corr) 79 | 80 | y = int(max_loc[1] + template.shape[0]/2.0 + ymin) 81 | x = int(max_loc[0] + template.shape[1]/2.0 + xmin) 82 | 83 | return y, x 84 | -------------------------------------------------------------------------------- /allensdk/eye_tracking/fit_ellipse.py: -------------------------------------------------------------------------------- 1 | """Module for ellipse fitting. 2 | 3 | The algorithm for the actual fitting is detailed at 4 | http://nicky.vanforeest.com/misc/fitEllipse/fitEllipse.html. 5 | """ 6 | import numpy as np 7 | from .ransac import RansacFitter 8 | import logging 9 | 10 | 11 | CONSTRAINT_MATRIX = np.zeros([6, 6]) 12 | CONSTRAINT_MATRIX[0, 2] = 2.0 13 | CONSTRAINT_MATRIX[2, 0] = 2.0 14 | CONSTRAINT_MATRIX[1, 1] = -1.0 15 | 16 | 17 | class EllipseFitter(object): 18 | """Wrapper class for performing ransac fitting of an ellipse. 19 | 20 | Parameters 21 | ---------- 22 | minimum_points_for_fit : int 23 | Number of points required to fit data. 24 | number_of_close_points : int 25 | Number of candidate outliers reselected as inliers required 26 | to consider a good fit. 27 | threshold : float 28 | Error threshold below which data should be considered an 29 | inlier. 30 | iterations : int 31 | Number of iterations to run. 32 | """ 33 | DEFAULT_MINIMUM_POINTS_FOR_FIT = 40 34 | DEFAULT_NUMBER_OF_CLOSE_POINTS = 15 35 | DEFAULT_THRESHOLD = 0.0001 36 | DEFAULT_ITERATIONS = 20 37 | 38 | def __init__(self, minimum_points_for_fit=DEFAULT_MINIMUM_POINTS_FOR_FIT, 39 | number_of_close_points=DEFAULT_NUMBER_OF_CLOSE_POINTS, 40 | threshold=DEFAULT_THRESHOLD, iterations=DEFAULT_ITERATIONS): 41 | self.update_params(minimum_points_for_fit=minimum_points_for_fit, 42 | number_of_close_points=number_of_close_points, 43 | iterations=iterations, threshold=threshold) 44 | self._fitter = RansacFitter() 45 | 46 | def update_params(self, 47 | minimum_points_for_fit=DEFAULT_MINIMUM_POINTS_FOR_FIT, 48 | number_of_close_points=DEFAULT_NUMBER_OF_CLOSE_POINTS, 49 | threshold=DEFAULT_THRESHOLD, 50 | iterations=DEFAULT_ITERATIONS): 51 | self.minimum_points_for_fit = minimum_points_for_fit 52 | self.number_of_close_points = number_of_close_points 53 | self.threshold = threshold 54 | self.iterations = iterations 55 | 56 | def fit(self, candidate_points, **kwargs): 57 | """Perform a fit on (y,x) points. 58 | 59 | Parameters 60 | ---------- 61 | candidate_points : list 62 | List of (y,x) points that may fit on the ellipse. 63 | 64 | Returns 65 | ------- 66 | ellipse_parameters : tuple 67 | (x, y, angle, semi_axis1, semi_axis2) ellipse parameters. 68 | error : float 69 | Fit error for the ellipse. 70 | """ 71 | data = np.array(candidate_points) 72 | params, error = self._fitter.fit( 73 | fit_ellipse, fit_errors, data, self.threshold, 74 | self.minimum_points_for_fit, self.number_of_close_points, 75 | self.iterations, **kwargs) 76 | if params is not None: 77 | x, y = ellipse_center(params) 78 | angle = ellipse_angle_of_rotation(params)*180/np.pi 79 | ax1, ax2 = ellipse_axis_length(params) 80 | return (x, y, angle, ax1, ax2), error 81 | else: 82 | return (np.nan, np.nan, np.nan, np.nan, np.nan), np.nan 83 | 84 | 85 | def fit_ellipse(data, max_radius=None, max_eccentricity=None): 86 | """Fit an ellipse to data. 87 | 88 | Parameters 89 | ---------- 90 | data : numpy.ndarray 91 | [n,2] array of (y,x) data points. 92 | max_radius : float 93 | Maximum radius to allow. 94 | max_eccentricity : float 95 | Maximum eccentricity to allow. 96 | 97 | Returns 98 | ------- 99 | ellipse_parameters : tuple 100 | (x, y, angle, semi_axis1, semi_axis2) ellipse parameters. 101 | error : float 102 | Mean error of the fit. 103 | """ 104 | try: 105 | y, x = data.T 106 | 107 | D = np.vstack([x*x, x*y, y*y, x, y, np.ones(len(y))]) 108 | S = np.dot(D, D.T) 109 | 110 | M = np.dot(np.linalg.inv(S), CONSTRAINT_MATRIX) 111 | U, s, V = np.linalg.svd(M) 112 | 113 | params = U.T[0] 114 | error = np.dot(params, np.dot(S, params))/len(y) 115 | if max_radius is not None: 116 | ax1, ax2 = ellipse_axis_length(params) 117 | if ax1 > max_radius or ax2 > max_radius: 118 | error = np.inf 119 | if max_eccentricity is not None: 120 | if eccentricity(params) > max_eccentricity: 121 | error = np.inf 122 | except Exception as e: 123 | logging.debug(e) # figure out which exception this is catching 124 | params = None 125 | error = np.inf 126 | 127 | return params, error 128 | 129 | 130 | def fit_errors(parameters, data): 131 | """Calculate the errors on each data point. 132 | 133 | Parameters 134 | ---------- 135 | parameters : numpy.ndarray 136 | Paramaters of the fit ellipse model. 137 | data : numpy.ndarray 138 | [n,2] array of (y,x) points. 139 | 140 | Returns 141 | ------- 142 | numpy.ndarray 143 | Squared error of the fit at each point in data. 144 | """ 145 | y, x = data.T 146 | D = np.vstack([x*x, x*y, y*y, x, y, np.ones(len(y))]) 147 | errors = (np.dot(parameters, D))**2 148 | 149 | return errors 150 | 151 | 152 | def quadratic_parameters(conic_parameters): 153 | """Get quadratic ellipse coefficients from conic parameters. 154 | 155 | Calculation from http://mathworld.wolfram.com/Ellipse.html 156 | 157 | Parameters 158 | ---------- 159 | conic_parameters : tuple 160 | (x, y, angle, semi_axis1, semi_axis2) ellipse parameters. 161 | 162 | Returns 163 | ------- 164 | quadratic_parameters : tuple 165 | Polynomial parameters for the ellipse. 166 | """ 167 | a = conic_parameters[0] 168 | b = conic_parameters[1]/2 169 | c = conic_parameters[2] 170 | d = conic_parameters[3]/2 171 | f = conic_parameters[4]/2 172 | g = conic_parameters[5] 173 | return (a, b, c, d, f, g) 174 | 175 | 176 | def ellipse_center(parameters): 177 | """Calculate the center of the ellipse given the model parameters. 178 | 179 | Calculation from http://mathworld.wolfram.com/Ellipse.html 180 | 181 | Parameters 182 | ---------- 183 | parameters : numpy.ndarray 184 | Parameters of the ellipse fit. 185 | 186 | Returns 187 | ------- 188 | center : numpy.ndarray 189 | [x,y] center of the ellipse. 190 | """ 191 | a, b, c, d, f, g = quadratic_parameters(parameters) 192 | num = b*b-a*c 193 | x0 = (c*d-b*f)/num 194 | y0 = (a*f-b*d)/num 195 | return np.array([x0, y0]) 196 | 197 | 198 | def ellipse_angle_of_rotation(parameters): 199 | """Calculate the rotation of the ellipse given the model parameters. 200 | 201 | Calculation from http://mathworld.wolfram.com/Ellipse.html 202 | 203 | Parameters 204 | ---------- 205 | parameters : numpy.ndarray 206 | Parameters of the ellipse fit. 207 | 208 | Returns 209 | ------- 210 | rotation : float 211 | Rotation of the ellipse. 212 | """ 213 | a, b, c, d, f, g = quadratic_parameters(parameters) 214 | return 0.5*np.arctan(2*b/(a-c)) 215 | 216 | 217 | def ellipse_axis_length(parameters): 218 | """Calculate the semi-axes lengths of the ellipse. 219 | 220 | Calculation from http://mathworld.wolfram.com/Ellipse.html 221 | 222 | Parameters 223 | ---------- 224 | parameters : numpy.ndarray 225 | Parameters of the ellipse fit. 226 | 227 | Returns 228 | ------- 229 | semi_axes : numpy.ndarray 230 | Semi-axes of the ellipse. 231 | """ 232 | a, b, c, d, f, g = quadratic_parameters(parameters) 233 | up = 2*(a*f*f+c*d*d+g*b*b-2*b*d*f-a*c*g) 234 | down1 = (b*b-a*c)*((c-a)*np.sqrt(1+4*b*b/((a-c)*(a-c)))-(c+a)) 235 | down2 = (b*b-a*c)*((a-c)*np.sqrt(1+4*b*b/((a-c)*(a-c)))-(c+a)) 236 | 237 | down1 = min(.0000000001, down1) 238 | down2 = min(.0000000001, down2) 239 | 240 | res1 = np.sqrt(up/down1) 241 | res2 = np.sqrt(up/down2) 242 | return np.array([res1, res2]) 243 | 244 | 245 | def not_on_ellipse(point, ellipse_params, tolerance): 246 | """Function that tests if a point is not on an ellipse. 247 | 248 | Parameters 249 | ---------- 250 | point : tuple 251 | (y, x) point. 252 | ellipse_params : numpy.ndarray 253 | Ellipse parameters to check against. 254 | tolerance : float 255 | Tolerance for determining point is on ellipse. 256 | 257 | Returns 258 | ------ 259 | not_on : bool 260 | True if `point` is not within `tolerance` of the ellipse. 261 | """ 262 | py, px = point 263 | x, y, r, a, b = ellipse_params 264 | r = np.radians(r) 265 | # get point in frame of unrotated ellipse at 0, 0 266 | tx = (px - x)*np.cos(-r) - (py-y)*np.sin(-r) 267 | ty = (px - x)*np.sin(-r) + (py-y)*np.cos(-r) 268 | err = np.abs((tx**2 / a**2) + (ty**2 / b**2) - 1) 269 | if err < tolerance: 270 | return False 271 | return True 272 | 273 | 274 | def ellipse_pass_filter(point, ellipse_params, tolerance, 275 | pupil_intensity_estimate=None, 276 | pupil_limits=None): 277 | """Function to pass or reject an ellipse candidate point. 278 | 279 | Points are rejected if they fall on the border defined by 280 | `ellipse_params`. If `pupil_limits` is provided and 281 | `pupil_intensity_limits` falls outside it the point is 282 | rejected as well. 283 | 284 | Parameters 285 | ---------- 286 | point : tuple 287 | (y, x) point. 288 | ellipse_params : numpy.ndarray 289 | Ellipse parameters to check against. 290 | tolerance : float 291 | Tolerance for determining point is on ellipse. 292 | pupil_intensity_estimage : float 293 | Estimated intensity of the pupil used for generating 294 | the point. 295 | pupil_limits : tuple 296 | (min, max) valid intensities for the pupil. 297 | 298 | Returns 299 | ------ 300 | passed : bool 301 | True if the point passes the filter and is a good candidate 302 | for fitting. 303 | """ 304 | passed = not_on_ellipse(point, ellipse_params, tolerance) 305 | if (pupil_limits is not None) and passed: 306 | in_range = (pupil_intensity_estimate >= pupil_limits[0]) and \ 307 | (pupil_intensity_estimate <= pupil_limits[1]) 308 | passed = in_range 309 | return passed 310 | 311 | 312 | def eccentricity(parameters): 313 | """Get the eccentricity of an ellipse from the conic parameters. 314 | 315 | Parameters 316 | ---------- 317 | parameters : numpy.ndarray 318 | Conic parameters of the ellipse. 319 | 320 | Returns 321 | ------- 322 | eccentricity : float 323 | Eccentricity of the ellipse. 324 | """ 325 | axes = ellipse_axis_length(parameters) 326 | minor = np.min(axes) 327 | major = np.max(axes) 328 | return 1 - (minor/major) 329 | -------------------------------------------------------------------------------- /allensdk/eye_tracking/frame_stream.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import traceback 3 | import cv2 4 | 5 | 6 | class FrameInputStream(object): 7 | def __init__(self, movie_path, num_frames=None, process_frame_cb=None): 8 | self.movie_path = movie_path 9 | self._num_frames = num_frames 10 | self._start = 0 11 | self._stop = num_frames 12 | self._step = 1 13 | self._i = self._start - self._step 14 | self._last_i = 0 15 | if process_frame_cb: 16 | self.process_frame_cb = process_frame_cb 17 | else: 18 | self.process_frame_cb = lambda f: f[:, :, 0].copy() 19 | self.frames_read = 0 20 | self.frame_cache = [] 21 | 22 | def next(self): 23 | return self.__next__() 24 | 25 | def __getitem__(self, key): 26 | if isinstance(key, int): 27 | if key >= self.num_frames or key < -self.num_frames: 28 | raise IndexError("Index {} out of range".format(key)) 29 | elif key >= 0: 30 | self._start = key 31 | else: 32 | self._start = self.num_frames + key 33 | self._stop = self._start + 1 34 | self._step = 1 35 | return list(self)[0] # force iteration and closing 36 | elif isinstance(key, slice): 37 | if key.step == 0: 38 | raise ValueError("slice step cannot be 0") 39 | self._start = key.start if key.start is not None else 0 40 | if key.stop is None: 41 | self._stop = self.num_frames 42 | elif key.stop < 0: 43 | self._stop = max(self.num_frames + key.stop, -1) 44 | else: 45 | self._stop = min(self.num_frames, key.stop) 46 | self._step = key.step if key.step is not None else 1 47 | return self 48 | else: 49 | raise KeyError("Key must be non-negative integer or slice, not {}" 50 | .format(key)) 51 | 52 | @property 53 | def num_frames(self): 54 | return self._num_frames if self._num_frames is not None else 0 55 | 56 | @property 57 | def frame_shape(self): 58 | raise NotImplementedError(("frame_shape must be implemented in a " 59 | "subclass")) 60 | 61 | def open(self): 62 | self.frames_read = 0 63 | 64 | def close(self): 65 | logging.debug("Read total frames %d", self.frames_read) 66 | 67 | def _error(self): 68 | pass 69 | 70 | def _seek_frame(self, i): 71 | raise NotImplementedError(("_seek_frame must be implemented in a " 72 | "subclass")) 73 | 74 | def _get_frame(self, i): 75 | raise NotImplementedError(("_get_frame must be implemented in a " 76 | "subclass")) 77 | 78 | def get_frame(self, i): 79 | if abs(i - self._last_i) > 1: 80 | self._seek_frame(i) 81 | self._last_i = self._i 82 | self._i = i 83 | self.frames_read += 1 84 | if self.frames_read % 100 == 0: 85 | logging.debug("Read frames %d", self.frames_read) 86 | return self.process_frame_cb(self._get_frame(self._i)) 87 | 88 | def __enter__(self): 89 | return self 90 | 91 | def __iter__(self): 92 | self._last_i = 0 93 | self._i = self._start - self._step 94 | self.open() 95 | logging.debug("Iterating over %s from %d to %s by step %d" % 96 | (self.movie_path, self._start, self._stop, self._step)) 97 | return self 98 | 99 | def __next__(self): 100 | if self._stop is None: 101 | self._stop = self.num_frames 102 | self._i = self._i + self._step 103 | if (self._step < 0 and self._i <= self._stop) or \ 104 | (self._step > 0 and self._i >= self._stop): 105 | self.close() 106 | raise StopIteration() 107 | else: 108 | return self.get_frame(self._i) 109 | 110 | def __exit__(self, exc_type, exc_value, tb): 111 | if exc_value: 112 | traceback.print_tb(tb) 113 | self._error() 114 | raise exc_value 115 | 116 | 117 | class CvInputStream(FrameInputStream): 118 | def __init__(self, movie_path, num_frames=None, process_frame_cb=None): 119 | super(CvInputStream, self).__init__(movie_path=movie_path, 120 | num_frames=num_frames, 121 | process_frame_cb=process_frame_cb) 122 | self.cap = None 123 | self._frame_shape = None 124 | self._stop = num_frames 125 | 126 | @property 127 | def num_frames(self): 128 | if self._num_frames is None: 129 | self.load_capture_properties() 130 | self._stop = self._num_frames 131 | return self._num_frames 132 | 133 | @property 134 | def frame_shape(self): 135 | if self._frame_shape is None: 136 | self.load_capture_properties() 137 | return self._frame_shape 138 | 139 | def load_capture_properties(self): 140 | close_after = False 141 | if self.cap is None: 142 | close_after = True 143 | self.cap = cv2.VideoCapture(self.movie_path) 144 | 145 | self._num_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) 146 | self._frame_shape = (int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)), 147 | int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))) 148 | 149 | if close_after: 150 | self.cap.release() 151 | self.cap = None 152 | 153 | def open(self): 154 | if self.cap: 155 | raise IOError("capture is open already") 156 | 157 | super(CvInputStream, self).open() 158 | 159 | self.cap = cv2.VideoCapture(self.movie_path) 160 | logging.debug("opened capture") 161 | 162 | def close(self): 163 | if self.cap is None: 164 | return 165 | 166 | self.cap.release() 167 | self.cap = None 168 | 169 | super(CvInputStream, self).close() 170 | 171 | def _seek_frame(self, i): 172 | if self.cap is None: 173 | raise IOError("capture is not open") 174 | self.cap.set(cv2.CAP_PROP_POS_FRAMES, i) 175 | 176 | def _get_frame(self, i): 177 | if self.cap is None: 178 | raise IOError("capture is not open") 179 | ret, frame = self.cap.read() 180 | return frame 181 | 182 | def _error(self): 183 | self.cap.release() 184 | self.cap = None 185 | 186 | 187 | class FrameOutputStream(object): 188 | def __init__(self, block_size=1): 189 | self.frames_processed = 0 190 | self.block_frames = [] 191 | self.block_size = block_size 192 | 193 | def open(self, movie_path): 194 | self.frames_processed = 0 195 | self.block_frames = [] 196 | self.movie_path = movie_path 197 | 198 | def _write_frames(self, frames): 199 | raise NotImplementedError() 200 | 201 | def write(self, frame): 202 | self.block_frames.append(frame) 203 | 204 | if len(self.block_frames) == self.block_size: 205 | self._write_frames(self.block_frames) 206 | self.frames_processed += len(self.block_frames) 207 | self.block_frames = [] 208 | 209 | def close(self): 210 | if self.block_frames: 211 | self._write_frames(self.block_frames) 212 | self.frames_processed += len(self.block_frames) 213 | self.block_frames = [] 214 | 215 | logging.debug("wrote %d frames", self.frames_processed) 216 | 217 | def __enter__(self): 218 | return self 219 | 220 | def __exit__(self, exc_type, exc_value, tb): 221 | if exc_value: 222 | raise exc_value 223 | self.close() 224 | 225 | 226 | class CvOutputStream(FrameOutputStream): 227 | def __init__(self, movie_path, frame_shape, frame_rate=30.0, 228 | fourcc="H264", is_color=True, block_size=1): 229 | super(CvOutputStream, self).__init__(block_size) 230 | 231 | self.frame_shape = frame_shape 232 | self.movie_path = movie_path 233 | self.fourcc = cv2.VideoWriter_fourcc(*str(fourcc)) 234 | self.frame_rate = frame_rate 235 | self.is_color = is_color 236 | self.writer = None 237 | 238 | def open(self, movie_path): 239 | super(CvOutputStream, self).open(movie_path) 240 | 241 | if self.writer: 242 | raise IOError("video writer is open already") 243 | 244 | self.writer = cv2.VideoWriter(movie_path, self.fourcc, 245 | self.frame_rate, self.frame_shape, 246 | self.is_color) 247 | logging.debug("opened video writer") 248 | 249 | def _write_frames(self, frames): 250 | if self.writer is None: 251 | self.open(self.movie_path) 252 | 253 | for frame in frames: 254 | self.writer.write(frame) 255 | 256 | def close(self): 257 | super(CvOutputStream, self).close() 258 | if self.writer is None: 259 | raise IOError("video writer is closed") 260 | 261 | self.writer.release() 262 | 263 | logging.debug("closed video writer") 264 | self.writer = None 265 | 266 | def __exit__(self, exc_type, exc_value, tb): 267 | if exc_value: 268 | self.writer.release() 269 | raise exc_value 270 | self.close() 271 | -------------------------------------------------------------------------------- /allensdk/eye_tracking/plotting.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | import numpy as np 4 | from skimage.draw import ellipse, ellipse_perimeter, polygon_perimeter 5 | import matplotlib 6 | with warnings.catch_warnings(): 7 | warnings.simplefilter("ignore") 8 | matplotlib.use("Agg") 9 | from matplotlib import pyplot as plt # noqa: E402 10 | 11 | 12 | class Annotator(object): 13 | """Class for annotating frames with ellipses. 14 | 15 | Parameters 16 | ---------- 17 | output_stream : object 18 | Object that implements a `write` method that accepts ndarray 19 | frames as well as `open` and `close` methods. 20 | """ 21 | COLORS = {"cr": (0, 0, 255), 22 | "pupil": (255, 0, 0)} 23 | 24 | def __init__(self, output_stream=None): 25 | self.output_stream = output_stream 26 | self.densities = {"pupil": None, 27 | "cr": None} 28 | self.clear_rc() 29 | 30 | def initiate_cumulative_data(self, shape): 31 | """Initialize density arrays to zeros of the correct shape. 32 | 33 | Parameters 34 | ---------- 35 | shape : tuple 36 | (height, width) to make the density arrays. 37 | """ 38 | self.densities["cr"] = np.zeros(shape, dtype=float) 39 | self.densities["pupil"] = np.zeros(shape, dtype=float) 40 | 41 | def clear_rc(self): 42 | """Clear the cached row and column ellipse border points.""" 43 | self._r = {"pupil": None, 44 | "cr": None} 45 | self._c = {"pupil": None, 46 | "cr": None} 47 | 48 | def update_rc(self, name, ellipse_parameters, shape): 49 | """Cache new row and column ellipse border points. 50 | 51 | Parameters 52 | ---------- 53 | name : string 54 | "pupil" or "cr" to reference the correct object in the 55 | lookup table. 56 | ellipse_parameters : tuple 57 | Conic parameters of the ellipse. 58 | shape : tuple 59 | (height, width) shape of image used to generate ellipse 60 | border points at the right rows and columns. 61 | 62 | Returns 63 | ------- 64 | cache_updated : bool 65 | Whether or not new values were cached. 66 | """ 67 | if np.any(np.isnan(ellipse_parameters)): 68 | return False 69 | if self._r[name] is None: 70 | self._r[name], self._c[name] = ellipse_perimeter_points( 71 | ellipse_parameters, shape) 72 | return True 73 | 74 | def _annotate(self, name, rgb_frame, ellipse_parameters): 75 | if self.update_rc(name, ellipse_parameters, rgb_frame.shape[:2]): 76 | color_by_points(rgb_frame, self._r[name], self._c[name], 77 | self.COLORS[name]) 78 | 79 | def annotate_frame(self, frame, pupil_parameters, cr_parameters, 80 | seed=None, pupil_candidates=None): 81 | """Annotate an image with ellipses for cr and pupil. 82 | 83 | If the annotator was initialized with an output stream, the 84 | frame will be written to the stream. 85 | 86 | Parameters 87 | ---------- 88 | frame : numpy.ndarray 89 | Grayscale image to annotate. 90 | pupil_parameters : tuple 91 | (x, y, r, a, b) ellipse parameters for pupil. 92 | cr_parameters : tuple 93 | (x, y, r, a, b) ellipse parameters for corneal reflection. 94 | seed : tuple 95 | (y, x) seed point of pupil. 96 | pupil_candidates : list 97 | List of (y, x) candidate points used for the ellipse 98 | fit of the pupil. 99 | 100 | Returns 101 | ------- 102 | rgb_frame : numpy.ndarray 103 | Color annotated frame. 104 | """ 105 | rgb_frame = get_rgb_frame(frame) 106 | if not np.any(np.isnan(pupil_parameters)): 107 | self._annotate("pupil", rgb_frame, pupil_parameters) 108 | if not np.any(np.isnan(cr_parameters)): 109 | self._annotate("cr", rgb_frame, cr_parameters) 110 | 111 | if seed is not None: 112 | color_by_points(rgb_frame, seed[0], seed[1], (0, 255, 0)) 113 | 114 | if pupil_candidates: 115 | arr = np.array(pupil_candidates) 116 | color_by_points(rgb_frame, arr[:, 0], arr[:, 1], (0, 255, 0)) 117 | 118 | if self.output_stream is not None: 119 | self.output_stream.write(rgb_frame) 120 | return rgb_frame 121 | 122 | def _density(self, name, frame, ellipse_parameters): 123 | if self.update_rc(name, ellipse_parameters, frame.shape): 124 | self.densities[name][self._r[name], self._c[name]] += 1 125 | 126 | def compute_density(self, frame, pupil_parameters, cr_parameters): 127 | """Update the density maps with from the current frame. 128 | 129 | Parameters 130 | ---------- 131 | frame : numpy.ndarray 132 | Input frame. 133 | pupil_parameters : tuple 134 | (x, y, r, a, b) ellipse parameters for pupil. 135 | cr_parameters : tuple 136 | (x, y, r, a, b) ellipse parameters for corneal reflection. 137 | """ 138 | # TODO: rename this to update_density 139 | if self.densities["pupil"] is None: 140 | self.initiate_cumulative_data(frame.shape) 141 | self._density("pupil", frame, pupil_parameters) 142 | self._density("cr", frame, cr_parameters) 143 | 144 | def annotate_with_cumulative_pupil(self, frame, filename=None): 145 | """Annotate frame with all pupil ellipses from the density map. 146 | 147 | Parameters 148 | ---------- 149 | frame : numpy.ndarray 150 | Grayscale frame to annotate. 151 | filename : string 152 | Filename to save annotated image to, if provided. 153 | 154 | Returns 155 | ------- 156 | rgb_frame : numpy.ndarray 157 | Annotated color frame. 158 | """ 159 | return annotate_with_cumulative(frame, self.densities["pupil"], 160 | (0, 0, 255), filename) 161 | 162 | def annotate_with_cumulative_cr(self, frame, filename=None): 163 | """Annotate frame with all cr ellipses from the density map. 164 | 165 | Parameters 166 | ---------- 167 | frame : numpy.ndarray 168 | Grayscale frame to annotate. 169 | filename : string 170 | Filename to save annotated image to, if provided. 171 | 172 | Returns 173 | ------- 174 | rgb_frame : numpy.ndarray 175 | Annotated color frame. 176 | """ 177 | return annotate_with_cumulative(frame, self.densities["cr"], 178 | (255, 0, 0), filename) 179 | 180 | def close(self): 181 | """Close the output stream if it exists.""" 182 | if self.output_stream is not None: 183 | self.output_stream.close() 184 | 185 | 186 | def get_rgb_frame(frame): 187 | """Convert a grayscale frame to an RGB frame. 188 | 189 | If the frame passed in already has 3 channels, it is simply returned. 190 | 191 | Parameters 192 | ---------- 193 | frame : numpy.ndarray 194 | Image frame. 195 | 196 | Returns 197 | ------- 198 | rgb_frame : numpy.ndarray 199 | [height,width,3] RGB frame. 200 | """ 201 | if frame.ndim == 3 and frame.shape[2] == 3: 202 | rgb_frame = frame 203 | elif frame.ndim == 2: 204 | rgb_frame = np.dstack([frame, frame, frame]) 205 | else: 206 | raise ValueError("Frame of shape {} is not valid".format(frame.shape)) 207 | return rgb_frame 208 | 209 | 210 | def annotate_with_cumulative(frame, density, rgb_vals=(255, 0, 0), 211 | filename=None): 212 | """Annotate frame with all values from `density`. 213 | 214 | Parameters 215 | ---------- 216 | frame : numpy.ndarray 217 | Grayscale frame to annotate. 218 | density : numpy.ndarray 219 | Array of the same shape as frame with non-zero values 220 | where the image should be annotated. 221 | rgb_vals : tuple 222 | (r, g, b) 0-255 color values for annotation. 223 | filename : string 224 | Filename to save annotated image to, if provided. 225 | 226 | Returns 227 | ------- 228 | rgb_frame : numpy.ndarray 229 | Annotated color frame. 230 | """ 231 | rgb_frame = get_rgb_frame(frame) 232 | if density is not None: 233 | mask = density > 0 234 | color_by_mask(rgb_frame, mask, rgb_vals) 235 | if filename is not None: 236 | plt.imsave(filename, rgb_frame) 237 | return rgb_frame 238 | 239 | 240 | def annotate_with_box(image, bounding_box, rgb_vals=(255, 0, 0), 241 | filename=None): 242 | """Annotate image with bounding box. 243 | 244 | Parameters 245 | ---------- 246 | image : numpy.ndarray 247 | Grayscale or RGB image to annotate. 248 | bounding_box : numpy.ndarray 249 | [xmin, xmax, ymin, ymax] bounding box. 250 | rgb_vals : tuple 251 | (r, g, b) 0-255 color values for annotation. 252 | filename : string 253 | Filename to save annotated image to, if provided. 254 | 255 | Returns 256 | ------- 257 | rgb_image : numpy.ndarray 258 | Annotated color image. 259 | """ 260 | rgb_image = get_rgb_frame(image) 261 | xmin, xmax, ymin, ymax = bounding_box 262 | r = np.array((ymin, ymin, ymax, ymax), dtype=int) 263 | c = np.array((xmin, xmax, xmax, xmin), dtype=int) 264 | rr, cc = polygon_perimeter(r, c, rgb_image.shape[:2]) 265 | color_by_points(rgb_image, rr, cc, rgb_vals) 266 | if filename is not None: 267 | plt.imsave(filename, rgb_image) 268 | return rgb_image 269 | 270 | 271 | def color_by_points(rgb_image, row_points, column_points, 272 | rgb_vals=(255, 0, 0)): 273 | """Color image at points indexed by row and column vectors. 274 | 275 | The image is recolored in-place. 276 | 277 | Parameters 278 | ---------- 279 | rgb_image : numpy.ndarray 280 | Color image to draw into. 281 | row_points : numpy.ndarray 282 | Vector of row indices to color in. 283 | column_points : numpy.ndarray 284 | Vector of column indices to color in. 285 | rgb_vals : tuple 286 | (r, g, b) 0-255 color values for annotation. 287 | """ 288 | for i, value in enumerate(rgb_vals): 289 | rgb_image[row_points, column_points, i] = value 290 | 291 | 292 | def color_by_mask(rgb_image, mask, rgb_vals=(255, 0, 0)): 293 | """Color image at points indexed by mask. 294 | 295 | The image is recolored in-place. 296 | 297 | Parameters 298 | ---------- 299 | rgb_image : numpy.ndarray 300 | Color image to draw into. 301 | mask : numpy.ndarray 302 | Boolean mask of points to color in. 303 | rgb_vals : tuple 304 | (r, g, b) 0-255 color values for annotation. 305 | """ 306 | for i, value in enumerate(rgb_vals): 307 | rgb_image[mask, i] = value 308 | 309 | 310 | def ellipse_points(ellipse_params, image_shape): 311 | """Generate row, column indices for filled ellipse. 312 | 313 | Parameters 314 | ---------- 315 | ellipse_params : tuple 316 | (x, y, r, a b) ellipse parameters. 317 | image_shape : tuple 318 | (height, width) shape of image. 319 | 320 | Returns 321 | ------- 322 | row_points : numpy.ndarray 323 | Row indices for filled ellipse. 324 | column_points : numpy.ndarray 325 | Column indices for filled ellipse. 326 | """ 327 | x, y, r, a, b = ellipse_params 328 | r = np.radians(-r) 329 | return ellipse(y, x, b, a, image_shape, r) 330 | 331 | 332 | def ellipse_perimeter_points(ellipse_params, image_shape): 333 | """Generate row, column indices for ellipse perimeter. 334 | 335 | Parameters 336 | ---------- 337 | ellipse_params : tuple 338 | (x, y, r, a b) ellipse parameters. 339 | image_shape : tuple 340 | (height, width) shape of image. 341 | 342 | Returns 343 | ------- 344 | row_points : numpy.ndarray 345 | Row indices for ellipse perimeter. 346 | column_points : numpy.ndarray 347 | Column indices for ellipse perimeter. 348 | """ 349 | x, y, r, a, b = ellipse_params 350 | r = np.radians(r) 351 | return ellipse_perimeter(int(y), int(x), int(b), int(a), r, image_shape) 352 | 353 | 354 | def get_filename(output_folder, prefix, image_type): 355 | """Helper function to build image filename. 356 | 357 | Parameters 358 | ---------- 359 | output_folder : string 360 | Folder for images. 361 | prefix : string 362 | Image filename without extension. 363 | image_type : string 364 | File extension for image (e.g. '.png'). 365 | 366 | Returns 367 | ------- 368 | filename : string 369 | Fill filename of image, or None if no output folder. 370 | """ 371 | if output_folder: 372 | filename = prefix + image_type 373 | return os.path.join(output_folder, filename) 374 | return None 375 | 376 | 377 | def plot_cumulative(pupil_density, cr_density, output_dir=None, show=False, 378 | image_type=".png"): 379 | """Plot cumulative density of ellipse fits for cr and pupil. 380 | 381 | Parameters 382 | ---------- 383 | pupil_density : numpy.ndarray 384 | Accumulated density of pupil perimeters. 385 | pupil_density : numpy.ndarray 386 | Accumulated density of cr perimeters. 387 | output_dir : string 388 | Output directory to store images. Images aren't saved if None 389 | is provided. 390 | show : bool 391 | Whether or not to call pyplot.show() after generating both 392 | plots. 393 | image_type : string 394 | Image extension for saving plots. 395 | """ 396 | dens = np.log(1+pupil_density) 397 | plot_density(np.max(dens) - dens, 398 | filename=get_filename(output_dir, "pupil_density", 399 | image_type), 400 | title="pupil density", 401 | show=False) 402 | dens = np.log(1+cr_density) 403 | plot_density(np.max(dens) - dens, 404 | filename=get_filename(output_dir, "cr_density", 405 | image_type), 406 | title="cr density", 407 | show=show) 408 | 409 | 410 | def plot_summary(pupil_params, cr_params, output_dir=None, show=False, 411 | image_type=".png"): 412 | """Plot timeseries of various pupil and cr parameters. 413 | 414 | Generates plots of pupil and cr parameters against frame number. 415 | The plots include (x, y) position, angle, and (semi-minor, 416 | semi-major) axes seperately for pupil and cr, for a total of 6 417 | plots. 418 | 419 | Parameters 420 | ---------- 421 | pupil_params : numpy.ndarray 422 | Array of pupil parameters at every frame. 423 | cr_params : numpy.ndarray 424 | Array of cr parameters at every frame. 425 | output_dir : string 426 | Output directory for storing saved images of plots. 427 | show : bool 428 | Whether or not to call pyplot.show() after generating the plots. 429 | image_type : string 430 | File extension to use if saving images to `output_dir`. 431 | """ 432 | plot_timeseries(pupil_params.T[0], "pupil x", x2=pupil_params.T[1], 433 | label2="pupil y", title="pupil position", 434 | filename=get_filename(output_dir, "pupil_position", 435 | image_type), 436 | show=False) 437 | plot_timeseries(cr_params.T[0], "cr x", x2=cr_params.T[1], 438 | label2="pupil y", title="cr position", 439 | filename=get_filename(output_dir, "cr_position", 440 | image_type), 441 | show=False) 442 | plot_timeseries(pupil_params.T[3], "pupil axis1", x2=pupil_params.T[4], 443 | label2="pupil axis2", title="pupil major/minor axes", 444 | filename=get_filename(output_dir, "pupil_axes", 445 | image_type), 446 | show=False) 447 | plot_timeseries(cr_params.T[3], "cr axis1", x2=cr_params.T[4], 448 | label2="cr axis1", title="cr major/minor axes", 449 | filename=get_filename(output_dir, "cr_axes", 450 | image_type), 451 | show=False) 452 | plot_timeseries(pupil_params.T[2], "pupil angle", title="pupil angle", 453 | filename=get_filename(output_dir, "pupil_angle", 454 | image_type), 455 | show=False) 456 | plot_timeseries(cr_params.T[2], "cr angle", title="cr angle", 457 | filename=get_filename(output_dir, "cr_angle", 458 | image_type), 459 | show=show) 460 | 461 | 462 | def plot_timeseries(x1, label1, x2=None, label2=None, title=None, 463 | filename=None, show=False): 464 | """Helper function to plot up to 2 timeseries against index. 465 | 466 | Parameters 467 | ---------- 468 | x1 : numpy.ndarray 469 | Array of values to plot. 470 | label1 : string 471 | Label for `x1` timeseries. 472 | x2 : numpy.ndarray 473 | Optional second array of values to plot. 474 | label2 : string 475 | Label for `x2` timeseries. 476 | title : string 477 | Title for the plot. 478 | filename : string 479 | Filename to save the plot to. 480 | show : bool 481 | Whether or not to call pyplot.show() after generating the plot. 482 | """ 483 | fig, ax = plt.subplots(1) 484 | ax.plot(x1, label=label1) 485 | if x2 is not None: 486 | ax.plot(x2, label=label2) 487 | ax.set_xlabel('frame index') 488 | if title: 489 | ax.set_title(title) 490 | ax.legend() 491 | if filename is not None: 492 | fig.savefig(filename) 493 | if show: 494 | plt.show() 495 | 496 | 497 | def plot_density(density, title=None, filename=None, show=False): 498 | """Plot cumulative density. 499 | 500 | Parameters 501 | ---------- 502 | density : numpy.ndarray 503 | Accumulated 2-D density map to plot. 504 | title : string 505 | Title for the plot. 506 | filename : string 507 | Filename to save the plot to. 508 | show : bool 509 | Whether or not to call pyplot.show() after generating the plot. 510 | """ 511 | fig, ax = plt.subplots(1) 512 | ax.imshow(density, cmap="gray", interpolation="nearest") 513 | if title: 514 | ax.set_title(title) 515 | if filename is not None: 516 | fig.savefig(filename) 517 | if show: 518 | plt.show() 519 | -------------------------------------------------------------------------------- /allensdk/eye_tracking/ransac.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class RansacFitter(object): 5 | def __init__(self): 6 | self.best_error = np.inf 7 | self.best_params = None 8 | 9 | def fit(self, fit_function, error_function, data, threshold, 10 | minimum_points_for_fit, number_of_close_points, iterations, 11 | **kwargs): 12 | """Find a best fit to a model using ransac. 13 | 14 | Parameters 15 | ---------- 16 | fit_function : callable 17 | Method that fits a model to `data`. 18 | error_function : callable 19 | Function that calculates error given `parameters` and a 20 | subset of `data`. Returns array of errors, one for each 21 | sample. 22 | data : numpy.ndarray 23 | Matrix of data points of shape (#samples, #items/sample). 24 | threshold : float 25 | Error threshold below which data should be considered an 26 | inlier. 27 | minimum_points_for_fit : int 28 | Number of points required to fit data. 29 | number_of_close_points : int 30 | Number of candidate outliers reselected as inliers required 31 | to consider a good fit. 32 | iterations : int 33 | Number of iterations to run. 34 | **kwargs 35 | Additional constraint keyword arguments passed to 36 | `fit_function`. 37 | 38 | Returns 39 | ------- 40 | best_params : numpy.ndarray 41 | Best parameters of the model. 42 | best_error : float 43 | Best error in the fitting. 44 | 45 | Raises 46 | ------ 47 | ValueError: 48 | If there is less data than `minimum_points_for_fit`. 49 | """ 50 | if data.shape[0] < minimum_points_for_fit: 51 | raise ValueError("Insufficient data for fit") 52 | self.best_error = np.inf 53 | self.best_params = None 54 | for i in range(iterations): 55 | parameters, error = fit_iteration(fit_function, error_function, 56 | data, threshold, 57 | minimum_points_for_fit, 58 | number_of_close_points, 59 | **kwargs) 60 | if error < self.best_error: 61 | self.best_params = parameters 62 | self.best_error = error 63 | return self.best_params, self.best_error 64 | 65 | 66 | def fit_iteration(fit_function, error_function, data, threshold, 67 | minimum_points_for_fit, number_of_close_points, 68 | **kwargs): 69 | """Perform one iteration of ransac model fitting. 70 | 71 | Parameters 72 | ---------- 73 | fit_function : callable 74 | Method that fits a model to `data`. 75 | error_function : callable 76 | Function that calculates error given `parameters` and a 77 | subset of `data`. Returns array of errors, one for each 78 | sample. 79 | data : numpy.ndarray 80 | Matrix of data points of shape (#samples, #items/sample). 81 | threshold : float 82 | Error threshold below which data should be considered an 83 | inlier. 84 | minimum_points_for_fit : int 85 | Number of points required to fit data. 86 | number_of_close_points : int 87 | Number of candidate outliers reselected as inliers required 88 | to consider a good fit. 89 | **kwargs 90 | Additional constraint keyword arguments passed to 91 | `fit_function`. 92 | 93 | Returns 94 | ------- 95 | tuple 96 | (model parameters, model error) 97 | """ 98 | inlier_idx, outlier_idx = partition_candidate_indices( 99 | data, minimum_points_for_fit) 100 | parameters, error = fit_function(data[inlier_idx, :], **kwargs) 101 | if parameters is not None: 102 | also_inlier_idx = check_outliers(error_function, parameters, data, 103 | outlier_idx, threshold) 104 | if len(also_inlier_idx) > number_of_close_points: 105 | idx = np.concatenate((inlier_idx, also_inlier_idx)) 106 | parameters, error = fit_function(data[idx, :], **kwargs) 107 | return parameters, error 108 | return None, np.inf 109 | 110 | 111 | def check_outliers(error_function, parameters, data, outlier_indices, 112 | threshold): 113 | """Check if any outliers should be inliers based on initial fit. 114 | 115 | Parameters 116 | ---------- 117 | error_function : callable 118 | Function that calculates error given `parameters` and a 119 | subset of `data`. Returns array of errors, one for each 120 | sample. 121 | parameters : numpy.ndarray 122 | Model parameters after some fit. 123 | data : numpy.ndarray 124 | Matrix of data points of shape (#samples, #items/sample). 125 | outlier_indices : numpy.ndarray 126 | Index array for initial outlier guess. 127 | threshold : float 128 | Error threshold below which data should be considered an 129 | inlier. 130 | 131 | Returns 132 | ------- 133 | numpy.ndarray 134 | Index array of new inliers. 135 | """ 136 | also_in_index = error_function(parameters, 137 | data[outlier_indices, :]) < threshold 138 | 139 | return outlier_indices[also_in_index] 140 | 141 | 142 | def partition_candidate_indices(data, minimum_points_for_fit): 143 | """Generation indices to partition data into inliers/outliers. 144 | 145 | Parameters 146 | ---------- 147 | data : np.ndarray 148 | Matrix of data points of shape (#samples, #items/sample). 149 | minimum_points_for_fit : int 150 | Minimum number of points required to attempt fit. 151 | 152 | Returns 153 | ------- 154 | tuple 155 | (inliers, outliers) tuple of index arrays for potential 156 | """ 157 | shuffled = np.random.permutation(np.arange(data.shape[0])) 158 | 159 | inliers = shuffled[:minimum_points_for_fit] 160 | outliers = shuffled[minimum_points_for_fit:] 161 | 162 | return inliers, outliers 163 | -------------------------------------------------------------------------------- /allensdk/eye_tracking/ui/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenInstitute/allensdk.eye_tracking/021cf0a023f02ec9f839fe606826fb6700c25a40/allensdk/eye_tracking/ui/__init__.py -------------------------------------------------------------------------------- /allensdk/eye_tracking/ui/__main__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import json 4 | import logging 5 | from allensdk.eye_tracking.ui.qt import QtWidgets, ViewerWindow 6 | from allensdk.eye_tracking import _schemas 7 | 8 | 9 | def load_config(config_file): 10 | config = None 11 | try: 12 | with open(config_file, "r") as f: 13 | config = json.load(f) 14 | except Exception as e: 15 | logging.error(e) 16 | return config 17 | 18 | 19 | def main(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("--profile", action="store_true") 22 | parser.add_argument("--config_file", type=str, default="") 23 | args, left = parser.parse_known_args() 24 | sys.argv = sys.argv[:1] + left 25 | config = None 26 | if args.config_file: 27 | config = load_config(args.config_file) 28 | app = QtWidgets.QApplication([]) 29 | w = ViewerWindow(_schemas.InputParameters, args.profile, config) 30 | w.show() 31 | sys.exit(app.exec_()) 32 | 33 | 34 | if __name__ == "__main__": 35 | main() 36 | -------------------------------------------------------------------------------- /allensdk/eye_tracking/ui/qt.py: -------------------------------------------------------------------------------- 1 | from qtpy import QtCore, QtWidgets, QtGui 2 | from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg 3 | from matplotlib.figure import Figure, SubplotParams 4 | import ast 5 | import os 6 | import json 7 | import cProfile 8 | from argschema.schemas import mm 9 | from argschema import ArgSchemaParser 10 | from allensdk.eye_tracking import _schemas 11 | from allensdk.eye_tracking.frame_stream import CvInputStream 12 | from allensdk.eye_tracking.eye_tracking import EyeTracker 13 | from allensdk.eye_tracking.plotting import annotate_with_box 14 | 15 | 16 | LITERAL_EVAL_TYPES = {_schemas.NumpyArray, _schemas.Bool} 17 | 18 | 19 | class DropFileMixin(object): 20 | """Mixin for accepting drag and drop of a file.""" 21 | file_dropped = QtCore.Signal(str) 22 | 23 | def dragEnterEvent(self, event): 24 | if event.mimeData().hasUrls(): 25 | event.accept() 26 | else: 27 | event.ignore() 28 | 29 | def dragMoveEvent(self, event): 30 | if event.mimeData().hasUrls(): 31 | event.accept() 32 | else: 33 | event.ignore() 34 | 35 | def dropEvent(self, event): 36 | if event.mimeData().hasUrls(): 37 | event.accept() 38 | filename = str(event.mimeData().urls()[0].toLocalFile()) 39 | self.file_dropped.emit(filename) 40 | else: 41 | event.ignore() 42 | 43 | 44 | class FieldWidget(QtWidgets.QLineEdit): 45 | """Widget for displaying and editing a schema field. 46 | 47 | Parameters 48 | ---------- 49 | key : string 50 | Name of the field. 51 | field : argschema.Field 52 | Argschema Field object containing serialization and default 53 | data information. 54 | parent : QtWidgets.QWidget 55 | Parent widget. 56 | """ 57 | def __init__(self, key, field, parent=None, **kwargs): 58 | self.field = field 59 | super(FieldWidget, self).__init__(self._get_default_string(), 60 | parent=parent) 61 | self.key = key 62 | self.setEnabled(not kwargs.get("read_only", False)) 63 | self.displayed = kwargs.get("visible", True) 64 | self.setVisible(self.displayed) 65 | 66 | def _get_default_string(self): 67 | if self.field.default == mm.missing: 68 | default = "" 69 | else: 70 | default = str(self.field.default) 71 | 72 | return default 73 | 74 | def reset(self): 75 | self.setText(self._get_default_string()) 76 | 77 | def get_json(self): 78 | """Get the JSON serializable data from this field. 79 | 80 | Returns 81 | ------- 82 | data : object 83 | JSON serializable data in the widget, or None if empty. 84 | """ 85 | raw_value = str(self.text()) 86 | if raw_value: 87 | if type(self.field) in LITERAL_EVAL_TYPES: 88 | try: 89 | raw_value = ast.literal_eval(raw_value) 90 | except SyntaxError: 91 | pass # let validation handle it 92 | value = self.field.deserialize(raw_value) 93 | if isinstance(self.field, _schemas.NumpyArray): 94 | value = value.tolist() 95 | return value 96 | return None 97 | 98 | 99 | class SchemaWidget(QtWidgets.QWidget): 100 | """Widget for displaying an ArgSchema. 101 | 102 | Parameters 103 | ---------- 104 | key : string 105 | The key of the schema if it is nested. 106 | schema : argschema.DefaultSchema 107 | The schema to create a widget for. 108 | parent : QtWidgets.QWidget 109 | Parent widget. 110 | """ 111 | def __init__(self, key, schema, parent=None, config=None): 112 | super(SchemaWidget, self).__init__(parent=parent) 113 | self.key = key 114 | self.schema = schema 115 | self.fields = {} 116 | self.config = config 117 | if config is None: 118 | self.config = {} 119 | self.layout = QtWidgets.QGridLayout() 120 | all_children_hidden = self._init_widgets() 121 | self.setLayout(self.layout) 122 | self.displayed = not all_children_hidden 123 | self.setVisible(self.displayed) 124 | 125 | def _init_widgets(self): 126 | fields = {} 127 | nested = {} 128 | all_hidden = True 129 | for k, v in self.schema.fields.items(): 130 | if isinstance(v, _schemas.Nested): 131 | w = SchemaWidget(k, v.schema, self, 132 | config=self.config.get(k, {})) 133 | nested[k] = w 134 | else: 135 | w = FieldWidget(k, v, self, **self.config.get(k, {})) 136 | fields[k] = w 137 | self.fields[k] = w 138 | if w.displayed: 139 | all_hidden = False 140 | self._init_layout(fields, nested) 141 | return all_hidden 142 | 143 | def _init_layout(self, fields, nested): 144 | i = 0 145 | if self.key is not None: 146 | label = QtWidgets.QLabel("{}".format(self.key)) 147 | label.setAlignment(QtCore.Qt.AlignCenter) 148 | self.layout.addWidget(label, i, 0, 1, 2) 149 | i += 1 150 | for k, v in sorted(fields.items()): 151 | label = QtWidgets.QLabel("{}: ".format(k)) 152 | label.setVisible(v.displayed) 153 | self.layout.addWidget(label, i, 0) 154 | self.layout.addWidget(v, i, 1) 155 | i += 1 156 | for k, v in sorted(nested.items()): 157 | self.layout.addWidget(v, i, 0, 1, 2) 158 | i += 1 159 | 160 | def reset(self): 161 | for widget in self.fields.values(): 162 | widget.reset() 163 | 164 | def get_json(self): 165 | """Get the JSON serializable data from this schema. 166 | 167 | Returns 168 | ------- 169 | data : object 170 | JSON serializable data in the widget, or None if empty. 171 | """ 172 | json_data = {} 173 | for key, value in self.fields.items(): 174 | data = value.get_json() 175 | if data is not None: 176 | json_data[key] = data 177 | if json_data: 178 | return json_data 179 | return None 180 | 181 | def update_value(self, attribute, value): 182 | """Update a value in the schema. 183 | 184 | Parameters 185 | ---------- 186 | attribute : string 187 | Attribute name to update. 188 | value : string 189 | Value to set the field edit box to. 190 | """ 191 | attrs = attribute.split(".", 1) 192 | if len(attrs) > 1: 193 | self.fields[attrs[0]].update_value(attrs[1], value) 194 | else: 195 | self.fields[attribute].setText(value) 196 | 197 | 198 | class InputJsonWidget(QtWidgets.QScrollArea): 199 | """Widget for displaying an editable input json in a scroll area. 200 | 201 | Parameters 202 | ---------- 203 | schema : argschema.DefaultSchema 204 | Schema from which to build widgets. 205 | parent : QtWidgets.QWidget 206 | Parent widget. 207 | """ 208 | def __init__(self, schema, parent=None, config=None): 209 | super(InputJsonWidget, self).__init__(parent=parent) 210 | self.schema_widget = SchemaWidget(None, schema, self, config) 211 | self.setWidget(self.schema_widget) 212 | 213 | def get_json(self): 214 | return self.schema_widget.get_json() 215 | 216 | def update_value(self, attribute, value): 217 | self.schema_widget.update_value(attribute, value) 218 | 219 | def reset(self): 220 | self.schema_widget.reset() 221 | 222 | 223 | class BBoxCanvas(FigureCanvasQTAgg, DropFileMixin): 224 | """Matplotlib canvas widget with drawable box. 225 | 226 | Parameters 227 | ---------- 228 | figure : matplotlib.Figure 229 | Matplob figure to contain in the canvas. 230 | """ 231 | box_updated = QtCore.Signal(int, int, int, int) 232 | file_dropped = QtCore.Signal(str) 233 | 234 | def __init__(self, figure): 235 | super(BBoxCanvas, self).__init__(figure) 236 | self.setAcceptDrops(True) 237 | self._im_shape = None 238 | self.rgba = (255, 255, 255, 20) 239 | self.begin = QtCore.QPoint() 240 | self.end = QtCore.QPoint() 241 | self.drawing = False 242 | 243 | @property 244 | def im_shape(self): 245 | if self._im_shape is None: 246 | return (self.height(), self.width()) 247 | return self._im_shape 248 | 249 | @im_shape.setter 250 | def im_shape(self, value): 251 | self._im_shape = value 252 | 253 | def set_rgb(self, r, g, b): 254 | """Set the RGB values for the bounding box tool. 255 | 256 | Parameters 257 | ---------- 258 | r : int 259 | Red channel value (0-255). 260 | g : int 261 | Green channel value (0-255). 262 | b : int 263 | Blue channel value (0-255). 264 | """ 265 | self.rgba = (r, g, b, 20) 266 | 267 | def paintEvent(self, event): 268 | """Event override for painting to draw bounding box. 269 | 270 | Parameters 271 | ---------- 272 | event : QtCore.QEvent 273 | The paint event. 274 | """ 275 | super(BBoxCanvas, self).paintEvent(event) 276 | if self.drawing: 277 | painter = QtGui.QPainter(self) 278 | brush = QtGui.QBrush(QtGui.QColor(*self.rgba)) 279 | painter.setBrush(brush) 280 | painter.drawRect(QtCore.QRect(self.begin, self.end)) 281 | 282 | def wheelEvent(self, event): 283 | """Event override to stop crashing of wheelEvent in PyQt5. 284 | 285 | Parameters 286 | ---------- 287 | event : QtCore.QEvent 288 | The wheel event. 289 | """ 290 | event.ignore() 291 | 292 | def mousePressEvent(self, event): 293 | """Event override for painting to initialize bounding box. 294 | 295 | Parameters 296 | ---------- 297 | event : QtCore.QEvent 298 | The mouse press event. 299 | """ 300 | self.begin = event.pos() 301 | self.end = event.pos() 302 | self.drawing = True 303 | self.update() 304 | 305 | def mouseMoveEvent(self, event): 306 | """Event override for painting to update bounding box. 307 | 308 | Parameters 309 | ---------- 310 | event : QtCore.QEvent 311 | The mouse move event. 312 | """ 313 | self.end = event.pos() 314 | self.update() 315 | 316 | def _scale_and_offset(self): 317 | h, w = self.im_shape 318 | im_aspect = float(h) / w 319 | aspect = float(self.height()) / self.width() 320 | if aspect > im_aspect: 321 | # taller than image, empty space padding bottom and top 322 | scale = float(w) / self.width() 323 | wimage_height = self.height() * scale 324 | xoffset = 0 325 | yoffset = int((wimage_height - h) / 2.0) 326 | else: 327 | scale = float(h) / self.height() 328 | wimage_width = self.width() * scale 329 | xoffset = int((wimage_width - w) / 2.0) 330 | yoffset = 0 331 | return scale, xoffset, yoffset 332 | 333 | def mouseReleaseEvent(self, event): 334 | """Event override for painting to finalize bounding box. 335 | 336 | Parameters 337 | ---------- 338 | event : QtCore.QEvent 339 | The mouse release event. 340 | """ 341 | self.end = event.pos() 342 | self.update() 343 | self.drawing = False 344 | scale, xoffset, yoffset = self._scale_and_offset() 345 | x1 = int(self.begin.x() * scale) - xoffset 346 | x2 = int(self.end.x() * scale) - xoffset 347 | y1 = int(self.begin.y() * scale) - yoffset 348 | y2 = int(self.end.y() * scale) - yoffset 349 | self.box_updated.emit(max(min(x1, x2), 1), 350 | min(max(x1, x2), self.im_shape[1] - 1), 351 | max(min(y1, y2), 1), 352 | min(max(y1, y2), self.im_shape[0] - 1)) 353 | 354 | 355 | class ViewerWidget(QtWidgets.QWidget): 356 | """Widget for tweaking eye tracking parameters and viewing output. 357 | 358 | Parameters 359 | ---------- 360 | schema_type : type(argschema.DefaultSchema) 361 | The input schema type. 362 | """ 363 | def __init__(self, schema_type, profile_runs=False, parent=None, 364 | config=None): 365 | super(ViewerWidget, self).__init__(parent=parent) 366 | self.profile_runs = profile_runs 367 | self.layout = QtWidgets.QGridLayout() 368 | self.config = config 369 | if config is None: 370 | self.config = {} 371 | self.schema_type = schema_type 372 | self.video = "./" 373 | self._init_widgets() 374 | self.tracker = EyeTracker(None, None) 375 | self.update_tracker() 376 | self.setLayout(self.layout) 377 | 378 | def _init_widgets(self): 379 | sp_params = SubplotParams(0, 0, 1, 1) 380 | self.figure = Figure(frameon=False, subplotpars=sp_params) 381 | self.axes = self.figure.add_subplot(111) 382 | self.canvas = BBoxCanvas(self.figure) 383 | self.json_view = InputJsonWidget( 384 | self.schema_type(), parent=self, 385 | config=self.config.get("input_json", {})) 386 | self.rerun_button = QtWidgets.QPushButton("Reprocess Frame", 387 | parent=self) 388 | self.pupil_radio = QtWidgets.QRadioButton("Pupil BBox", parent=self) 389 | self.cr_radio = QtWidgets.QRadioButton("CR BBox", parent=self) 390 | self.slider = QtWidgets.QSlider(parent=self) 391 | self.slider.setMinimum(0) 392 | self.slider.setOrientation(QtCore.Qt.Horizontal) 393 | self._connect_signals() 394 | self._init_layout() 395 | 396 | def _init_layout(self): 397 | self.layout.addWidget(self.canvas, 0, 0, 1, 2) 398 | self.layout.addWidget(self.json_view, 0, 2, 1, 2) 399 | self.layout.addWidget(self.slider, 1, 0, 1, 2) 400 | self.layout.addWidget(self.rerun_button, 2, 0) 401 | self.layout.addWidget(self.pupil_radio, 2, 1) 402 | self.layout.addWidget(self.cr_radio, 2, 2) 403 | 404 | def _connect_signals(self): 405 | self.slider.sliderReleased.connect(self.show_frame) 406 | self.rerun_button.clicked.connect(self.update_tracker) 407 | self.canvas.box_updated.connect(self.update_bbox) 408 | self.canvas.file_dropped.connect(self.load_video) 409 | self.pupil_radio.clicked.connect(self._setup_bbox) 410 | self.cr_radio.clicked.connect(self._setup_bbox) 411 | 412 | def _setup_bbox(self): 413 | if self.pupil_radio.isChecked(): 414 | self.canvas.set_rgb(0, 0, 255) 415 | elif self.cr_radio.isChecked(): 416 | self.canvas.set_rgb(255, 0, 0) 417 | 418 | def update_bbox(self, xmin, xmax, ymin, ymax): 419 | bbox = [xmin, xmax, ymin, ymax] 420 | if self.pupil_radio.isChecked(): 421 | self.json_view.update_value("pupil_bounding_box", str(bbox)) 422 | self.update_tracker() 423 | elif self.cr_radio.isChecked(): 424 | self.json_view.update_value("cr_bounding_box", str(bbox)) 425 | self.update_tracker() 426 | 427 | def _parse_args(self, json_data): 428 | try: 429 | mod = ArgSchemaParser(input_data=json_data, 430 | schema_type=self.schema_type) 431 | return mod.args 432 | except Exception as e: 433 | self._json_error_popup(e) 434 | return None 435 | 436 | def get_json_data(self): 437 | try: 438 | return self.json_view.get_json() 439 | except Exception as e: 440 | self._json_error_popup(e) 441 | 442 | def update_tracker(self): 443 | json_data = self.get_json_data() 444 | if json_data is None: 445 | return 446 | input_source = os.path.normpath(json_data.get("input_source", "./")) 447 | load = False 448 | if not os.path.isfile(input_source): 449 | json_data["input_source"] = os.path.abspath(__file__) 450 | elif self.video != input_source: 451 | load = True 452 | args = self._parse_args(json_data) 453 | if args: 454 | self.tracker.update_fit_parameters( 455 | args["starburst"], args["ransac"], 456 | args["pupil_bounding_box"], args["cr_bounding_box"], 457 | **args["eye_params"]) 458 | if load: 459 | self._load_video(input_source) 460 | elif self.tracker.input_stream is not None: 461 | self.show_frame() 462 | 463 | def save_json(self): 464 | json_data = self.get_json_data() 465 | if json_data is None: 466 | return 467 | valid = self._parse_args(json_data) 468 | if valid: 469 | path = self.config.get("json_save_path", "./") 470 | base, _ = os.path.splitext( 471 | os.path.basename(json_data["input_source"])) 472 | default_filename = os.path.join(path, base + ".json") 473 | filepath, _ = QtWidgets.QFileDialog.getSaveFileName( 474 | self, "Json file", default_filename) 475 | if os.path.exists(os.path.dirname(filepath)): 476 | with open(filepath, "w") as f: 477 | json.dump(json_data, f, indent=1) 478 | 479 | def load_video(self, filepath=None): 480 | if filepath is None: 481 | filepath, _ = QtWidgets.QFileDialog.getOpenFileName( 482 | self, "Select video") 483 | filepath = filepath.strip("'\" ") 484 | if os.path.exists(filepath): 485 | self.json_view.reset() 486 | self.json_view.update_value("input_source", 487 | os.path.normpath(filepath)) 488 | self.update_tracker() 489 | 490 | def _load_video(self, path): 491 | self.video = os.path.normpath(path) 492 | input_stream = CvInputStream(self.video) 493 | self.tracker.input_stream = input_stream 494 | self.slider.setMaximum(input_stream.num_frames-1) 495 | self.slider.setValue(0) 496 | self.show_frame() 497 | 498 | def show_frame(self, n=None): 499 | self.axes.clear() 500 | frame = self.tracker.input_stream[self.slider.value()] 501 | self.canvas.im_shape = self.tracker.im_shape 502 | if self.profile_runs: 503 | p = cProfile.Profile() 504 | p.enable() 505 | self.tracker.last_pupil_color = self.tracker.min_pupil_value 506 | cr, pupil, cr_err, pupil_err = self.tracker.process_image(frame) 507 | anno = self.tracker.annotator.annotate_frame( 508 | self.tracker.current_image, pupil, cr, self.tracker.current_seed, 509 | self.tracker.current_pupil_candidates) 510 | self.tracker.annotator.clear_rc() 511 | if self.profile_runs: 512 | p.disable() 513 | p.print_stats('cumulative') 514 | anno = annotate_with_box(anno, self.tracker.cr_bounding_box, 515 | (0, 0, 255)) 516 | anno = annotate_with_box(anno, self.tracker.pupil_bounding_box, 517 | (255, 0, 0)) 518 | self.axes.imshow(anno[:, :, ::-1], interpolation="none") 519 | self.axes.axis("off") 520 | self.canvas.draw() 521 | 522 | def _json_error_popup(self, msg): 523 | message = "
Error parsing input json: \n{}
".format(msg) 524 | box = QtWidgets.QMessageBox(self) 525 | box.setText(message) 526 | box.exec_() 527 | 528 | 529 | class ViewerWindow(QtWidgets.QMainWindow): 530 | def __init__(self, schema_type, profile_runs=False, config=None): 531 | super(ViewerWindow, self).__init__() 532 | self.setWindowTitle("Eye Tracking Configuration Tool") 533 | self.widget = ViewerWidget(schema_type, profile_runs=profile_runs, 534 | parent=self, config=config) 535 | self.setCentralWidget(self.widget) 536 | self._init_menubar() 537 | 538 | def _init_menubar(self): 539 | file_menu = self.menuBar().addMenu("&File") 540 | load = file_menu.addAction("Load Video") 541 | save = file_menu.addAction("Save JSON") 542 | load.triggered.connect(self.widget.load_video) 543 | save.triggered.connect(self.widget.save_json) 544 | -------------------------------------------------------------------------------- /allensdk/eye_tracking/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def good_coordinate_mask(xs, ys, shape): 5 | """Generate a coordinate mask inside `shape`. 6 | 7 | Parameters 8 | ---------- 9 | xs : np.ndarray 10 | X indices. 11 | ys : np.ndarray 12 | Y indices. 13 | shape : tuple 14 | (height, width) shape of image. 15 | 16 | Returns 17 | ------- 18 | np.ndarray 19 | Logical mask for `xs` and `xs` to ensure they are inside image. 20 | """ 21 | return np.logical_and(np.logical_and(ys >= 0, ys < shape[0]), 22 | np.logical_and(xs >= 0, xs < shape[1])) 23 | 24 | 25 | def get_ray_values(xs, ys, image): 26 | """Get values of image along a set of rays. 27 | 28 | Parameters 29 | ---------- 30 | xs : np.ndarray 31 | X indices of rays. 32 | ys : np.ndarray 33 | Y indices of rays. 34 | image : np.ndarray 35 | Image to get values from. 36 | 37 | Returns 38 | ------- 39 | list 40 | List of arrays of image values along each ray. 41 | """ 42 | ray_values = [] 43 | for i in range(xs.shape[0]): 44 | mask = good_coordinate_mask(xs[i], ys[i], image.shape) 45 | xm = xs[i][mask] 46 | ym = ys[i][mask] 47 | ray_values.append(image[ym, xm]) 48 | 49 | return ray_values 50 | 51 | 52 | def generate_ray_indices(index_length, n_rays): 53 | """Generate index arrays for rays emanating in a circle from a point. 54 | 55 | Rays have start point at 0,0. 56 | 57 | Parameters 58 | ---------- 59 | index_length : int 60 | Length of each index array. 61 | n_rays : int 62 | Number of rays to generate. Rays are evenly distributed about 63 | 360 degrees. 64 | 65 | Returns 66 | ------- 67 | tuple 68 | (xs, ys) tuple of [n_rays,index_length] index matrices. 69 | """ 70 | angles = (np.arange(n_rays)*2.0*np.pi/n_rays).reshape(n_rays, 1) 71 | xs = np.arange(index_length).reshape(1, index_length) 72 | ys = np.zeros((1, index_length)) 73 | 74 | return rotate_rays(xs, ys, angles) 75 | 76 | 77 | def rotate_rays(xs, ys, angles): 78 | """Rotate index arrays about angles. 79 | 80 | Parameters 81 | ---------- 82 | xs : np.ndarray 83 | Unrotated x-index array of shape [1,n]. 84 | ys : np.ndarray 85 | Unrotated y-index array of shape [1,n]. 86 | angles : np.adarray 87 | Angles over which to rotate of shape [m,1]. 88 | 89 | Returns 90 | ------- 91 | tuple 92 | (xs, ys) tuple of [m,n] index matrices. 93 | """ 94 | cosines = np.cos(angles) 95 | sines = np.sin(angles) 96 | x_rot = np.dot(cosines, xs) + np.dot(sines, ys) 97 | y_rot = np.dot(cosines, ys) - np.dot(sines, xs) 98 | 99 | return x_rot.astype(np.int64), y_rot.astype(np.int64) 100 | -------------------------------------------------------------------------------- /appveyor.yml: -------------------------------------------------------------------------------- 1 | build: false 2 | 3 | environment: 4 | matrix: 5 | - MINICONDA: "C:\\Miniconda-x64" 6 | PYTHON: 2.7 7 | - MINICONDA: "C:\\Miniconda36-x64" 8 | PYTHON: 3.6 9 | 10 | install: 11 | - set PATH=%MINICONDA%;%MINICONDA%\\Scripts;%PATH% 12 | - conda config --set always_yes yes --set changeps1 no 13 | - conda update -q conda 14 | - conda create -q -n test-environment python=%PYTHON% pip 15 | - activate test-environment 16 | - if %PYTHON% == 2.7 conda install scikit-image 17 | - conda install -c conda-forge opencv 18 | - conda install -c conda-forge pyqt 19 | - pip install -r test_requirements.txt 20 | - pip install . 21 | 22 | test_script: 23 | - pytest -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | 10 | # User-friendly check for sphinx-build 11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 12 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 13 | endif 14 | 15 | # Internal variables. 16 | PAPEROPT_a4 = -D latex_paper_size=a4 17 | PAPEROPT_letter = -D latex_paper_size=letter 18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 19 | # the i18n builder cannot share the environment and doctrees with the others 20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 21 | 22 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext 23 | 24 | help: 25 | @echo "Please use \`make ' where is one of" 26 | @echo " html to make standalone HTML files" 27 | @echo " dirhtml to make HTML files named index.html in directories" 28 | @echo " singlehtml to make a single large HTML file" 29 | @echo " pickle to make pickle files" 30 | @echo " json to make JSON files" 31 | @echo " htmlhelp to make HTML files and a HTML help project" 32 | @echo " qthelp to make HTML files and a qthelp project" 33 | @echo " devhelp to make HTML files and a Devhelp project" 34 | @echo " epub to make an epub" 35 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 36 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 37 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 38 | @echo " text to make text files" 39 | @echo " man to make manual pages" 40 | @echo " texinfo to make Texinfo files" 41 | @echo " info to make Texinfo files and run them through makeinfo" 42 | @echo " gettext to make PO message catalogs" 43 | @echo " changes to make an overview of all changed/added/deprecated items" 44 | @echo " xml to make Docutils-native XML files" 45 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 46 | @echo " linkcheck to check all external links for integrity" 47 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 48 | 49 | clean: 50 | rm -rf $(BUILDDIR)/* 51 | 52 | html: 53 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 54 | @echo 55 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 56 | 57 | dirhtml: 58 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 59 | @echo 60 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 61 | 62 | singlehtml: 63 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 64 | @echo 65 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 66 | 67 | pickle: 68 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 69 | @echo 70 | @echo "Build finished; now you can process the pickle files." 71 | 72 | json: 73 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 74 | @echo 75 | @echo "Build finished; now you can process the JSON files." 76 | 77 | htmlhelp: 78 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 79 | @echo 80 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 81 | ".hhp project file in $(BUILDDIR)/htmlhelp." 82 | 83 | qthelp: 84 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 85 | @echo 86 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 87 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 88 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/eye_tracking.qhcp" 89 | @echo "To view the help file:" 90 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/eye_tracking.qhc" 91 | 92 | devhelp: 93 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 94 | @echo 95 | @echo "Build finished." 96 | @echo "To view the help file:" 97 | @echo "# mkdir -p $$HOME/.local/share/devhelp/eye_tracking" 98 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/eye_tracking" 99 | @echo "# devhelp" 100 | 101 | epub: 102 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 103 | @echo 104 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 105 | 106 | latex: 107 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 108 | @echo 109 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 110 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 111 | "(use \`make latexpdf' here to do that automatically)." 112 | 113 | latexpdf: 114 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 115 | @echo "Running LaTeX files through pdflatex..." 116 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 117 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 118 | 119 | latexpdfja: 120 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 121 | @echo "Running LaTeX files through platex and dvipdfmx..." 122 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 123 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 124 | 125 | text: 126 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 127 | @echo 128 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 129 | 130 | man: 131 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 132 | @echo 133 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 134 | 135 | texinfo: 136 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 137 | @echo 138 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 139 | @echo "Run \`make' in that directory to run these through makeinfo" \ 140 | "(use \`make info' here to do that automatically)." 141 | 142 | info: 143 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 144 | @echo "Running Texinfo files through makeinfo..." 145 | make -C $(BUILDDIR)/texinfo info 146 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 147 | 148 | gettext: 149 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 150 | @echo 151 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 152 | 153 | changes: 154 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 155 | @echo 156 | @echo "The overview file is in $(BUILDDIR)/changes." 157 | 158 | linkcheck: 159 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 160 | @echo 161 | @echo "Link check complete; look for any errors in the above output " \ 162 | "or in $(BUILDDIR)/linkcheck/output.txt." 163 | 164 | doctest: 165 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 166 | @echo "Testing of doctests in the sources finished, look at the " \ 167 | "results in $(BUILDDIR)/doctest/output.txt." 168 | 169 | xml: 170 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 171 | @echo 172 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 173 | 174 | pseudoxml: 175 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 176 | @echo 177 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 178 | -------------------------------------------------------------------------------- /docs/_static/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AllenInstitute/allensdk.eye_tracking/021cf0a023f02ec9f839fe606826fb6700c25a40/docs/_static/.gitkeep -------------------------------------------------------------------------------- /docs/allensdk.eye_tracking.rst: -------------------------------------------------------------------------------- 1 | allensdk\.eye\_tracking package 2 | =============================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | allensdk.eye_tracking.ui 10 | 11 | Submodules 12 | ---------- 13 | 14 | allensdk\.eye\_tracking\.eye\_tracking module 15 | --------------------------------------------- 16 | 17 | .. automodule:: allensdk.eye_tracking.eye_tracking 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | 22 | allensdk\.eye\_tracking\.feature\_extraction module 23 | --------------------------------------------------- 24 | 25 | .. automodule:: allensdk.eye_tracking.feature_extraction 26 | :members: 27 | :undoc-members: 28 | :show-inheritance: 29 | 30 | allensdk\.eye\_tracking\.fit\_ellipse module 31 | -------------------------------------------- 32 | 33 | .. automodule:: allensdk.eye_tracking.fit_ellipse 34 | :members: 35 | :undoc-members: 36 | :show-inheritance: 37 | 38 | allensdk\.eye\_tracking\.frame\_stream module 39 | --------------------------------------------- 40 | 41 | .. automodule:: allensdk.eye_tracking.frame_stream 42 | :members: 43 | :undoc-members: 44 | :show-inheritance: 45 | 46 | allensdk\.eye\_tracking\.plotting module 47 | ---------------------------------------- 48 | 49 | .. automodule:: allensdk.eye_tracking.plotting 50 | :members: 51 | :undoc-members: 52 | :show-inheritance: 53 | 54 | allensdk\.eye\_tracking\.ransac module 55 | -------------------------------------- 56 | 57 | .. automodule:: allensdk.eye_tracking.ransac 58 | :members: 59 | :undoc-members: 60 | :show-inheritance: 61 | 62 | allensdk\.eye\_tracking\.utils module 63 | ------------------------------------- 64 | 65 | .. automodule:: allensdk.eye_tracking.utils 66 | :members: 67 | :undoc-members: 68 | :show-inheritance: 69 | 70 | 71 | Module contents 72 | --------------- 73 | 74 | .. automodule:: allensdk.eye_tracking 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | -------------------------------------------------------------------------------- /docs/allensdk.eye_tracking.ui.rst: -------------------------------------------------------------------------------- 1 | allensdk\.eye\_tracking\.ui package 2 | =================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | allensdk\.eye\_tracking\.ui\.qt module 8 | -------------------------------------- 9 | 10 | .. automodule:: allensdk.eye_tracking.ui.qt 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: -------------------------------------------------------------------------------- /docs/allensdk.rst: -------------------------------------------------------------------------------- 1 | allensdk package 2 | ================ 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | allensdk.eye_tracking 10 | 11 | Module contents 12 | --------------- 13 | 14 | .. automodule:: allensdk 15 | :members: 16 | :undoc-members: 17 | :show-inheritance: 18 | -------------------------------------------------------------------------------- /docs/authors.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../AUTHORS.rst 2 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # eye_tracking documentation build configuration file, created by 5 | # sphinx-quickstart on Tue Jul 9 22:26:36 2013. 6 | # 7 | # This file is execfile()d with the current directory set to its 8 | # containing dir. 9 | # 10 | # Note that not all possible configuration values are present in this 11 | # autogenerated file. 12 | # 13 | # All configuration values have a default; values that are commented out 14 | # serve to show the default. 15 | 16 | import sys 17 | import os 18 | 19 | # If extensions (or modules to document with autodoc) are in another 20 | # directory, add these directories to sys.path here. If the directory is 21 | # relative to the documentation root, use os.path.abspath to make it 22 | # absolute, like shown here. 23 | #sys.path.insert(0, os.path.abspath('.')) 24 | 25 | # Get the project root dir, which is the parent dir of this 26 | cwd = os.getcwd() 27 | project_root = os.path.dirname(cwd) 28 | 29 | # Insert the project root dir as the first element in the PYTHONPATH. 30 | # This lets us ensure that the source package is imported, and that its 31 | # version is used. 32 | sys.path.insert(0, project_root) 33 | 34 | import allensdk.eye_tracking 35 | 36 | # -- General configuration --------------------------------------------- 37 | 38 | # If your documentation needs a minimal Sphinx version, state it here. 39 | #needs_sphinx = '1.0' 40 | 41 | # Add any Sphinx extension module names here, as strings. They can be 42 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 43 | extensions = ['sphinx.ext.autodoc', 44 | 'sphinx.ext.viewcode', 45 | 'sphinx.ext.napoleon', 46 | 'sphinx.ext.githubpages'] 47 | 48 | # Add any paths that contain templates here, relative to this directory. 49 | templates_path = ['_templates'] 50 | 51 | # The suffix of source filenames. 52 | source_suffix = '.rst' 53 | 54 | # The encoding of source files. 55 | #source_encoding = 'utf-8-sig' 56 | 57 | # The master toctree document. 58 | master_doc = 'index' 59 | 60 | # General information about the project. 61 | project = u'AllenSDK Eye Tracking' 62 | copyright = u"2018. The Allen Institute." 63 | 64 | # The version info for the project you're documenting, acts as replacement 65 | # for |version| and |release|, also used in various other places throughout 66 | # the built documents. 67 | # 68 | # The short X.Y version. 69 | version = allensdk.eye_tracking.__version__ 70 | # The full version, including alpha/beta/rc tags. 71 | release = allensdk.eye_tracking.__version__ 72 | 73 | # The language for content autogenerated by Sphinx. Refer to documentation 74 | # for a list of supported languages. 75 | #language = None 76 | 77 | # There are two options for replacing |today|: either, you set today to 78 | # some non-false value, then it is used: 79 | #today = '' 80 | # Else, today_fmt is used as the format for a strftime call. 81 | #today_fmt = '%B %d, %Y' 82 | 83 | # List of patterns, relative to source directory, that match files and 84 | # directories to ignore when looking for source files. 85 | exclude_patterns = ['_build'] 86 | 87 | # The reST default role (used for this markup: `text`) to use for all 88 | # documents. 89 | #default_role = None 90 | 91 | # If true, '()' will be appended to :func: etc. cross-reference text. 92 | #add_function_parentheses = True 93 | 94 | # If true, the current module name will be prepended to all description 95 | # unit titles (such as .. function::). 96 | #add_module_names = True 97 | 98 | # If true, sectionauthor and moduleauthor directives will be shown in the 99 | # output. They are ignored by default. 100 | #show_authors = False 101 | 102 | # The name of the Pygments (syntax highlighting) style to use. 103 | pygments_style = 'sphinx' 104 | 105 | # A list of ignored prefixes for module index sorting. 106 | #modindex_common_prefix = [] 107 | 108 | # If true, keep warnings as "system message" paragraphs in the built 109 | # documents. 110 | #keep_warnings = False 111 | 112 | 113 | # -- Options for HTML output ------------------------------------------- 114 | 115 | # The theme to use for HTML and HTML Help pages. See the documentation for 116 | # a list of builtin themes. 117 | 118 | import sphinx_rtd_theme 119 | html_theme = "sphinx_rtd_theme" 120 | 121 | # Theme options are theme-specific and customize the look and feel of a 122 | # theme further. For a list of options available for each theme, see the 123 | # documentation. 124 | #html_theme_options = {} 125 | 126 | # Add any paths that contain custom themes here, relative to this directory. 127 | #html_theme_path = [] 128 | 129 | # The name for this set of Sphinx documents. If None, it defaults to 130 | # " v documentation". 131 | #html_title = None 132 | 133 | # A shorter title for the navigation bar. Default is the same as 134 | # html_title. 135 | #html_short_title = None 136 | 137 | # The name of an image file (relative to this directory) to place at the 138 | # top of the sidebar. 139 | #html_logo = None 140 | 141 | # The name of an image file (within the static path) to use as favicon 142 | # of the docs. This file should be a Windows icon file (.ico) being 143 | # 16x16 or 32x32 pixels large. 144 | #html_favicon = None 145 | 146 | # Add any paths that contain custom static files (such as style sheets) 147 | # here, relative to this directory. They are copied after the builtin 148 | # static files, so a file named "default.css" will overwrite the builtin 149 | # "default.css". 150 | html_static_path = ['_static'] 151 | 152 | # If not '', a 'Last updated on:' timestamp is inserted at every page 153 | # bottom, using the given strftime format. 154 | #html_last_updated_fmt = '%b %d, %Y' 155 | 156 | # If true, SmartyPants will be used to convert quotes and dashes to 157 | # typographically correct entities. 158 | #html_use_smartypants = True 159 | 160 | # Custom sidebar templates, maps document names to template names. 161 | #html_sidebars = {} 162 | 163 | # Additional templates that should be rendered to pages, maps page names 164 | # to template names. 165 | #html_additional_pages = {} 166 | 167 | # If false, no module index is generated. 168 | #html_domain_indices = True 169 | 170 | # If false, no index is generated. 171 | #html_use_index = True 172 | 173 | # If true, the index is split into individual pages for each letter. 174 | #html_split_index = False 175 | 176 | # If true, links to the reST sources are added to the pages. 177 | #html_show_sourcelink = True 178 | 179 | # If true, "Created using Sphinx" is shown in the HTML footer. 180 | # Default is True. 181 | #html_show_sphinx = True 182 | 183 | # If true, "(C) Copyright ..." is shown in the HTML footer. 184 | # Default is True. 185 | #html_show_copyright = True 186 | 187 | # If true, an OpenSearch description file will be output, and all pages 188 | # will contain a tag referring to it. The value of this option 189 | # must be the base URL from which the finished HTML is served. 190 | #html_use_opensearch = '' 191 | 192 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 193 | #html_file_suffix = None 194 | 195 | # Output file base name for HTML help builder. 196 | htmlhelp_basename = 'eye_trackingdoc' 197 | 198 | 199 | # -- Options for LaTeX output ------------------------------------------ 200 | 201 | latex_elements = { 202 | # The paper size ('letterpaper' or 'a4paper'). 203 | #'papersize': 'letterpaper', 204 | 205 | # The font size ('10pt', '11pt' or '12pt'). 206 | #'pointsize': '10pt', 207 | 208 | # Additional stuff for the LaTeX preamble. 209 | #'preamble': '', 210 | } 211 | 212 | # Grouping the document tree into LaTeX files. List of tuples 213 | # (source start file, target name, title, author, documentclass 214 | # [howto/manual]). 215 | latex_documents = [ 216 | ('index', 'eye_tracking.tex', 217 | u'AllenSDK Eye Tracking Documentation', 218 | u'Jed Perkins', 'manual'), 219 | ] 220 | 221 | # The name of an image file (relative to this directory) to place at 222 | # the top of the title page. 223 | #latex_logo = None 224 | 225 | # For "manual" documents, if this is true, then toplevel headings 226 | # are parts, not chapters. 227 | #latex_use_parts = False 228 | 229 | # If true, show page references after internal links. 230 | #latex_show_pagerefs = False 231 | 232 | # If true, show URL addresses after external links. 233 | #latex_show_urls = False 234 | 235 | # Documents to append as an appendix to all manuals. 236 | #latex_appendices = [] 237 | 238 | # If false, no module index is generated. 239 | #latex_domain_indices = True 240 | 241 | 242 | # -- Options for manual page output ------------------------------------ 243 | 244 | # One entry per manual page. List of tuples 245 | # (source start file, name, description, authors, manual section). 246 | man_pages = [ 247 | ('index', 'allensdk.eye_tracking', 248 | u'AllenSDK Eye Tracking Documentation', 249 | [u'Jed Perkins'], 1) 250 | ] 251 | 252 | # If true, show URL addresses after external links. 253 | #man_show_urls = False 254 | 255 | 256 | # -- Options for Texinfo output ---------------------------------------- 257 | 258 | # Grouping the document tree into Texinfo files. List of tuples 259 | # (source start file, target name, title, author, 260 | # dir menu entry, description, category) 261 | texinfo_documents = [ 262 | ('index', 'allensdk.eye_tracking', 263 | u'AllenSDK Eye Tracking Documentation', 264 | u'Jed Perkins', 265 | 'allensdk.eye_tracking', 266 | 'Allen Institute package for mouse eye tracking.', 267 | 'Miscellaneous'), 268 | ] 269 | 270 | # Documents to append as an appendix to all manuals. 271 | #texinfo_appendices = [] 272 | 273 | # If false, no module index is generated. 274 | #texinfo_domain_indices = True 275 | 276 | # How to display URL addresses: 'footnote', 'no', or 'inline'. 277 | #texinfo_show_urls = 'footnote' 278 | 279 | # If true, do not generate a @detailmenu in the "Top" node's menu. 280 | #texinfo_no_detailmenu = False 281 | -------------------------------------------------------------------------------- /docs/history.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../HISTORY.rst 2 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | AllenSDK Eye Tracking documentation 2 | =================================== 3 | 4 | Contents: 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | 9 | installation 10 | usage 11 | authors 12 | history 13 | modules 14 | 15 | 16 | Indices and tables 17 | ================== 18 | 19 | * :ref:`genindex` 20 | * :ref:`modindex` 21 | * :ref:`search` 22 | 23 | License 24 | ------- 25 | 26 | `The project is licensed under the BSD Clause 2 license with a non-commercial use clause. 27 | `_ -------------------------------------------------------------------------------- /docs/installation.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | Installation 3 | ============ 4 | 5 | The video IO is done using OpenCV's video functionality. Unfortunately, 6 | OpenCV on pip seems to not be built with the necessary backend, as the 7 | methods fail silently. As a result, we have not included OpenCV in the 8 | requirements and it is necessary to get it seperately, built with the 9 | video capture and writing functional. Additionally, on some platforms 10 | scikit-image does not build easily from source and the developers don't 11 | have bindary distributions for all platforms yet. The simplest way to 12 | install these difficult dependencies is to use conda:: 13 | 14 | conda install scikit-image 15 | conda install -c conda-forge opencv=3.3.0 16 | conda install -c conda-forge pyqt 17 | 18 | Pinning the version of opencv does not seem to be required for Windows, 19 | but is for linux, as the latest seems to have a bug with VideoCapture. 20 | The rest of the dependencies are all in the requirements, so to install 21 | just clone or download the repository and then from inside the top 22 | level directory either run:: 23 | 24 | pip install . 25 | 26 | or:: 27 | 28 | python setup.py install 29 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | REM Command file for Sphinx documentation 4 | 5 | if "%SPHINXBUILD%" == "" ( 6 | set SPHINXBUILD=sphinx-build 7 | ) 8 | set BUILDDIR=_build 9 | set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . 10 | set I18NSPHINXOPTS=%SPHINXOPTS% . 11 | if NOT "%PAPER%" == "" ( 12 | set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% 13 | set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% 14 | ) 15 | 16 | if "%1" == "" goto help 17 | 18 | if "%1" == "help" ( 19 | :help 20 | echo.Please use `make ^` where ^ is one of 21 | echo. html to make standalone HTML files 22 | echo. dirhtml to make HTML files named index.html in directories 23 | echo. singlehtml to make a single large HTML file 24 | echo. pickle to make pickle files 25 | echo. json to make JSON files 26 | echo. htmlhelp to make HTML files and a HTML help project 27 | echo. qthelp to make HTML files and a qthelp project 28 | echo. devhelp to make HTML files and a Devhelp project 29 | echo. epub to make an epub 30 | echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter 31 | echo. text to make text files 32 | echo. man to make manual pages 33 | echo. texinfo to make Texinfo files 34 | echo. gettext to make PO message catalogs 35 | echo. changes to make an overview over all changed/added/deprecated items 36 | echo. xml to make Docutils-native XML files 37 | echo. pseudoxml to make pseudoxml-XML files for display purposes 38 | echo. linkcheck to check all external links for integrity 39 | echo. doctest to run all doctests embedded in the documentation if enabled 40 | goto end 41 | ) 42 | 43 | if "%1" == "clean" ( 44 | for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i 45 | del /q /s %BUILDDIR%\* 46 | goto end 47 | ) 48 | 49 | 50 | %SPHINXBUILD% 2> nul 51 | if errorlevel 9009 ( 52 | echo. 53 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 54 | echo.installed, then set the SPHINXBUILD environment variable to point 55 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 56 | echo.may add the Sphinx directory to PATH. 57 | echo. 58 | echo.If you don't have Sphinx installed, grab it from 59 | echo.http://sphinx-doc.org/ 60 | exit /b 1 61 | ) 62 | 63 | if "%1" == "html" ( 64 | %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html 65 | if errorlevel 1 exit /b 1 66 | echo. 67 | echo.Build finished. The HTML pages are in %BUILDDIR%/html. 68 | goto end 69 | ) 70 | 71 | if "%1" == "dirhtml" ( 72 | %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml 73 | if errorlevel 1 exit /b 1 74 | echo. 75 | echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. 76 | goto end 77 | ) 78 | 79 | if "%1" == "singlehtml" ( 80 | %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml 81 | if errorlevel 1 exit /b 1 82 | echo. 83 | echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. 84 | goto end 85 | ) 86 | 87 | if "%1" == "pickle" ( 88 | %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle 89 | if errorlevel 1 exit /b 1 90 | echo. 91 | echo.Build finished; now you can process the pickle files. 92 | goto end 93 | ) 94 | 95 | if "%1" == "json" ( 96 | %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json 97 | if errorlevel 1 exit /b 1 98 | echo. 99 | echo.Build finished; now you can process the JSON files. 100 | goto end 101 | ) 102 | 103 | if "%1" == "htmlhelp" ( 104 | %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp 105 | if errorlevel 1 exit /b 1 106 | echo. 107 | echo.Build finished; now you can run HTML Help Workshop with the ^ 108 | .hhp project file in %BUILDDIR%/htmlhelp. 109 | goto end 110 | ) 111 | 112 | if "%1" == "qthelp" ( 113 | %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp 114 | if errorlevel 1 exit /b 1 115 | echo. 116 | echo.Build finished; now you can run "qcollectiongenerator" with the ^ 117 | .qhcp project file in %BUILDDIR%/qthelp, like this: 118 | echo.^> qcollectiongenerator %BUILDDIR%\qthelp\eye_tracking.qhcp 119 | echo.To view the help file: 120 | echo.^> assistant -collectionFile %BUILDDIR%\qthelp\eye_tracking.ghc 121 | goto end 122 | ) 123 | 124 | if "%1" == "devhelp" ( 125 | %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp 126 | if errorlevel 1 exit /b 1 127 | echo. 128 | echo.Build finished. 129 | goto end 130 | ) 131 | 132 | if "%1" == "epub" ( 133 | %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub 134 | if errorlevel 1 exit /b 1 135 | echo. 136 | echo.Build finished. The epub file is in %BUILDDIR%/epub. 137 | goto end 138 | ) 139 | 140 | if "%1" == "latex" ( 141 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 142 | if errorlevel 1 exit /b 1 143 | echo. 144 | echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. 145 | goto end 146 | ) 147 | 148 | if "%1" == "latexpdf" ( 149 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 150 | cd %BUILDDIR%/latex 151 | make all-pdf 152 | cd %BUILDDIR%/.. 153 | echo. 154 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 155 | goto end 156 | ) 157 | 158 | if "%1" == "latexpdfja" ( 159 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 160 | cd %BUILDDIR%/latex 161 | make all-pdf-ja 162 | cd %BUILDDIR%/.. 163 | echo. 164 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 165 | goto end 166 | ) 167 | 168 | if "%1" == "text" ( 169 | %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text 170 | if errorlevel 1 exit /b 1 171 | echo. 172 | echo.Build finished. The text files are in %BUILDDIR%/text. 173 | goto end 174 | ) 175 | 176 | if "%1" == "man" ( 177 | %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man 178 | if errorlevel 1 exit /b 1 179 | echo. 180 | echo.Build finished. The manual pages are in %BUILDDIR%/man. 181 | goto end 182 | ) 183 | 184 | if "%1" == "texinfo" ( 185 | %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo 186 | if errorlevel 1 exit /b 1 187 | echo. 188 | echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. 189 | goto end 190 | ) 191 | 192 | if "%1" == "gettext" ( 193 | %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale 194 | if errorlevel 1 exit /b 1 195 | echo. 196 | echo.Build finished. The message catalogs are in %BUILDDIR%/locale. 197 | goto end 198 | ) 199 | 200 | if "%1" == "changes" ( 201 | %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes 202 | if errorlevel 1 exit /b 1 203 | echo. 204 | echo.The overview file is in %BUILDDIR%/changes. 205 | goto end 206 | ) 207 | 208 | if "%1" == "linkcheck" ( 209 | %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck 210 | if errorlevel 1 exit /b 1 211 | echo. 212 | echo.Link check complete; look for any errors in the above output ^ 213 | or in %BUILDDIR%/linkcheck/output.txt. 214 | goto end 215 | ) 216 | 217 | if "%1" == "doctest" ( 218 | %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest 219 | if errorlevel 1 exit /b 1 220 | echo. 221 | echo.Testing of doctests in the sources finished, look at the ^ 222 | results in %BUILDDIR%/doctest/output.txt. 223 | goto end 224 | ) 225 | 226 | if "%1" == "xml" ( 227 | %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml 228 | if errorlevel 1 exit /b 1 229 | echo. 230 | echo.Build finished. The XML files are in %BUILDDIR%/xml. 231 | goto end 232 | ) 233 | 234 | if "%1" == "pseudoxml" ( 235 | %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml 236 | if errorlevel 1 exit /b 1 237 | echo. 238 | echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. 239 | goto end 240 | ) 241 | 242 | :end 243 | -------------------------------------------------------------------------------- /docs/modules.rst: -------------------------------------------------------------------------------- 1 | allensdk 2 | ======== 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | allensdk 8 | -------------------------------------------------------------------------------- /docs/usage.rst: -------------------------------------------------------------------------------- 1 | ===== 2 | Usage 3 | ===== 4 | 5 | After installing the package, and entry point is created so it can be run 6 | from the command line. To minimally run with the default settings:: 7 | 8 | allensdk.eye_tracking --input_source 9 | 10 | To see all options that can be set at the command line:: 11 | 12 | allensdk.eye_tracking --help 13 | 14 | There are a lot of options that can be set, so often it can be more 15 | convenient to store them in a JSON-formatted file which can be used like:: 16 | 17 | allensdk.eye_tracking --input_json 18 | 19 | The input json can be combined with other command line argument, which will 20 | take precedence over anything in the json. There is a UI tool for adjusting 21 | and saving input parameters that can be used by running:: 22 | 23 | allensdk.eye_tracking_ui -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | argschema>=1.16.0 2 | numpy 3 | scipy 4 | scikit-image>=0.13.1 5 | matplotlib 6 | qtpy 7 | -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | pip>=8.1.2 2 | bumpversion>=0.5.3 3 | wheel>=0.29.0 4 | watchdog>=0.8.3 5 | flake8>=2.6.0 6 | tox>=2.3.1 7 | coverage>=4.1 8 | Sphinx>=1.4.8 9 | PyYAML>=3.11 10 | pytest>=2.9.2 11 | sphinx_rtd_theme 12 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 1.2.1 3 | commit = True 4 | tag = True 5 | 6 | [bumpversion:file:setup.py] 7 | search = version='{current_version}' 8 | replace = version='{new_version}' 9 | 10 | [bumpversion:file:'{package_name}'/__init__.py] 11 | search = __version__ = '{current_version}' 12 | replace = __version__ = '{new_version}' 13 | 14 | [bdist_wheel] 15 | universal = 1 16 | 17 | [flake8] 18 | exclude = docs 19 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open('requirements.txt', 'r') as f: 4 | requirements = f.read().splitlines() 5 | 6 | with open('test_requirements.txt', 'r') as f: 7 | test_requirements = f.read().splitlines() 8 | 9 | setup( 10 | name='allensdk_eye_tracking', 11 | version='1.2.1', 12 | description="""Allen Institute package for mouse eye tracking.""", 13 | author="Jed Perkins", 14 | author_email="jedp@alleninstitute.org", 15 | url='https://github.com/AllenInstitute/allensdk.eye_tracking', 16 | packages=find_packages(), 17 | include_package_data=True, 18 | install_requires=requirements, 19 | entry_points={ 20 | 'console_scripts': [ 21 | 'allensdk.eye_tracking = allensdk.eye_tracking.__main__:main', 22 | 'allensdk.eye_tracking_ui = allensdk.eye_tracking.ui.__main__:main' 23 | ] 24 | }, 25 | setup_requires=['pytest-runner'], 26 | tests_require=test_requirements 27 | ) 28 | -------------------------------------------------------------------------------- /test/test_eye_tracking.py: -------------------------------------------------------------------------------- 1 | from allensdk.eye_tracking import eye_tracking as et 2 | from skimage.draw import circle 3 | import numpy as np 4 | import pytest 5 | from itertools import islice 6 | from mock import patch, MagicMock, ANY 7 | 8 | 9 | def image(shape=(200, 200), cr_radius=10, cr_center=(100, 100), 10 | pupil_radius=30, pupil_center=(100, 100)): 11 | im = np.ones(shape, dtype=np.uint8)*128 12 | r, c = circle(pupil_center[0], pupil_center[1], pupil_radius, shape) 13 | im[r, c] = 0 14 | r, c = circle(cr_center[0], cr_center[1], cr_radius, shape) 15 | im[r, c] = 255 16 | return im 17 | 18 | 19 | class InputStream(object): 20 | def __init__(self, n_frames=5): 21 | self.n_frames = n_frames 22 | self.frame_shape = image().shape 23 | self._i = 0 24 | 25 | def __getitem__(self, key): 26 | return islice(self, key.start, key.stop, key.step) 27 | 28 | def __iter__(self): 29 | self._i = 0 30 | return self 31 | 32 | def __next__(self): 33 | if self._i >= self.n_frames: 34 | raise StopIteration() 35 | self._i += 1 36 | return image() 37 | 38 | def next(self): 39 | return self.__next__() 40 | 41 | @property 42 | def num_frames(self): 43 | return self.n_frames 44 | 45 | def close(self): 46 | pass 47 | 48 | 49 | class OutputStream(object): 50 | def __init__(self, shape): 51 | self.shape = (shape[0], shape[1], 3) 52 | self.closed = False 53 | 54 | def write(self, array): 55 | assert(array.shape == self.shape) 56 | 57 | def close(self): 58 | self.closed = True 59 | 60 | 61 | def test_invalid_point_type(): 62 | pg = et.PointGenerator(100, 10, 1.5, 1.5, 10, 10) 63 | values = np.arange(50, dtype=np.uint8) 64 | with pytest.raises(ValueError): 65 | pg.threshold_crossing(pg.xs[0], pg.ys[0], values, "blah") 66 | 67 | 68 | @pytest.mark.parametrize(("threshold_factor,threshold_pixels,point_type,ray," 69 | "raises"), [ 70 | (1.5, 10, "pupil", 0, False), 71 | (5, 20, "cr", 5, False), 72 | (3, 40, "pupil", 0, True) 73 | ]) 74 | def test_threshold_crossing(threshold_factor, threshold_pixels, point_type, 75 | ray, raises): 76 | pg = et.PointGenerator(100, 10, threshold_factor, threshold_factor, 77 | threshold_pixels, threshold_pixels) 78 | values = np.arange(50, dtype=np.uint8) 79 | if raises: 80 | with pytest.raises(ValueError): 81 | pg.threshold_crossing(pg.xs[ray], pg.ys[ray], values, point_type) 82 | else: 83 | t = pg.get_threshold(values, pg.threshold_pixels[point_type], 84 | pg.threshold_factor[point_type]) 85 | y, x = pg.threshold_crossing(pg.xs[ray], pg.ys[ray], values, 86 | point_type) 87 | idx = np.where(pg.xs[ray] == x) 88 | if pg.above_threshold[point_type]: 89 | assert(idx == np.argmax(values[threshold_pixels:] > t) + 90 | threshold_pixels) 91 | else: 92 | assert(idx == np.argmax(values[threshold_pixels:] < t) + 93 | threshold_pixels) 94 | 95 | 96 | @pytest.mark.parametrize("image,seed,above", [ 97 | (image(), (100, 100), True), 98 | (image(), (100, 100), False) 99 | ]) 100 | def test_get_candidate_points(image, seed, above): 101 | pg = et.PointGenerator(100, 10, 1, 10) 102 | pg.get_candidate_points(image, seed, above) 103 | 104 | 105 | @pytest.mark.parametrize(("input_stream,output_stream," 106 | "starburst_params,ransac_params,pupil_bounding_box," 107 | "cr_bounding_box,kwargs"), [ 108 | (None, 109 | None, 110 | {"n_rays": 20, "cr_threshold_factor": 1.4, "cr_threshold_pixels": 5, 111 | "pupil_threshold_factor": 1.4, "pupil_threshold_pixels": 5, 112 | "index_length": 100}, 113 | {"iterations": 20, "threshold": 1, "minimum_points_for_fit": 10, 114 | "number_of_close_points": 3}, 115 | None, 116 | None, 117 | {}), 118 | (None, 119 | None, 120 | None, 121 | None, 122 | [20, 40, 20, 40], 123 | [20, 40, 20, 40], 124 | {}) 125 | ]) 126 | def test_eye_tracker_init(input_stream, output_stream, starburst_params, 127 | ransac_params, pupil_bounding_box, 128 | cr_bounding_box, kwargs): 129 | tracker = et.EyeTracker(input_stream, output_stream, 130 | starburst_params, ransac_params, 131 | pupil_bounding_box, cr_bounding_box, **kwargs) 132 | assert(tracker.im_shape is None) 133 | if pupil_bounding_box is None: 134 | test_pupil_bbox = et.default_bounding_box(None) 135 | else: 136 | test_pupil_bbox = pupil_bounding_box 137 | if cr_bounding_box is None: 138 | test_cr_bbox = et.default_bounding_box(None) 139 | else: 140 | test_cr_bbox = cr_bounding_box 141 | assert(np.all(tracker.pupil_bounding_box == test_pupil_bbox)) 142 | assert(np.all(tracker.cr_bounding_box == test_cr_bbox)) 143 | assert(input_stream == tracker.input_stream) 144 | if output_stream is None: 145 | assert(tracker.annotator.output_stream is None) 146 | else: 147 | assert(tracker.annotator.output_stream is not None) 148 | 149 | 150 | @pytest.mark.parametrize(("input_stream,output_stream," 151 | "starburst_params,ransac_params,pupil_bounding_box," 152 | "cr_bounding_box,kwargs"), [ 153 | (InputStream(), 154 | None, 155 | {"n_rays": 20, "cr_threshold_factor": 1.4, "cr_threshold_pixels": 5, 156 | "pupil_threshold_factor": 1.4, "pupil_threshold_pixels": 5, 157 | "index_length": 100}, 158 | {"iterations": 20, "threshold": 1, "minimum_points_for_fit": 10, 159 | "number_of_close_points": 3}, 160 | None, 161 | None, 162 | {}), 163 | (None, 164 | None, 165 | None, 166 | None, 167 | [20, 40, 20, 40], 168 | [20, 40, 20, 40], 169 | {}) 170 | ]) 171 | def test_update_fit_parameters(input_stream, output_stream, starburst_params, 172 | ransac_params, pupil_bounding_box, 173 | cr_bounding_box, kwargs): 174 | tracker = et.EyeTracker(input_stream) 175 | tracker.update_fit_parameters(starburst_params, ransac_params, 176 | pupil_bounding_box, cr_bounding_box, 177 | **kwargs) 178 | if input_stream is None: 179 | assert(tracker.im_shape is None) 180 | else: 181 | assert(tracker.im_shape == input_stream.frame_shape) 182 | if pupil_bounding_box is None: 183 | test_pupil_bbox = et.default_bounding_box(tracker.im_shape) 184 | else: 185 | test_pupil_bbox = pupil_bounding_box 186 | if cr_bounding_box is None: 187 | test_cr_bbox = et.default_bounding_box(tracker.im_shape) 188 | else: 189 | test_cr_bbox = cr_bounding_box 190 | assert(np.all(tracker.pupil_bounding_box == test_pupil_bbox)) 191 | assert(np.all(tracker.cr_bounding_box == test_cr_bbox)) 192 | assert(input_stream == tracker.input_stream) 193 | if output_stream is None: 194 | assert(tracker.annotator.output_stream is None) 195 | else: 196 | assert(tracker.annotator.output_stream is not None) 197 | 198 | 199 | @pytest.mark.parametrize("pupil_params,recolor_cr", [ 200 | (np.array((np.nan, np.nan, np.nan, np.nan, np.nan)), True), 201 | (np.array((1, 2, 3, 4, 5)), False), 202 | (np.array((1, 2, 3, 4, 5)), True), 203 | 204 | ]) 205 | @patch("allensdk.eye_tracking.eye_tracking.ellipse_points", 206 | return_value=(5, 5)) 207 | def test_update_last_pupil_color(mock_ellipse_points, pupil_params, 208 | recolor_cr): 209 | input_stream = InputStream() 210 | tracker = et.EyeTracker(input_stream, recolor_cr=recolor_cr) 211 | shape = input_stream.frame_shape 212 | with patch.object(tracker, "blurred_image", 213 | MagicMock(return_value=np.zeros(shape))) as blur: 214 | with patch.object(tracker, "cr_filled_image", 215 | MagicMock(return_value=np.ones(shape))) as cr: 216 | tracker.update_last_pupil_color(pupil_params) 217 | if np.any(np.isnan(pupil_params)): 218 | assert cr.__getitem__.call_count == 0 219 | assert blur.__getitem__.call_count == 0 220 | assert mock_ellipse_points.call_count == 0 221 | elif recolor_cr: 222 | cr.__getitem__.assert_called_once() 223 | mock_ellipse_points.assert_called_once_with(pupil_params, 224 | ANY) 225 | assert blur.__getitem__.call_count == 0 226 | else: 227 | blur.__getitem__.assert_called_once() 228 | mock_ellipse_points.assert_called_once_with(pupil_params, 229 | ANY) 230 | assert cr.__getitem__.call_count == 0 231 | 232 | 233 | @pytest.mark.parametrize(("image,input_stream,output_stream," 234 | "starburst_params,ransac_params,pupil_bounding_box," 235 | "cr_bounding_box,kwargs"), [ 236 | (image(), 237 | None, 238 | None, 239 | {"n_rays": 20, "cr_threshold_factor": 1.4, "cr_threshold_pixels": 5, 240 | "pupil_threshold_factor": 1.4, "pupil_threshold_pixels": 5, 241 | "index_length": 100}, 242 | {"iterations": 20, "threshold": 1, "minimum_points_for_fit": 10, 243 | "number_of_close_points": 3}, 244 | None, 245 | None, 246 | {}), 247 | (image(), 248 | None, 249 | None, 250 | {"n_rays": 20, "cr_threshold_factor": 1.4, "cr_threshold_pixels": 5, 251 | "pupil_threshold_factor": 1.4, "pupil_threshold_pixels": 5, 252 | "index_length": 100}, 253 | {"iterations": 20, "threshold": 1, "minimum_points_for_fit": 10, 254 | "number_of_close_points": 3}, 255 | None, 256 | None, 257 | {"clip_pupil_values": False, "adaptive_pupil": False}), 258 | (image(), 259 | None, 260 | None, 261 | {"n_rays": 20, "cr_threshold_factor": 1.4, "cr_threshold_pixels": 5, 262 | "pupil_threshold_factor": 1.4, "pupil_threshold_pixels": 5, 263 | "index_length": 100}, 264 | {"iterations": 20, "threshold": 1, "minimum_points_for_fit": 10, 265 | "number_of_close_points": 3}, 266 | None, 267 | None, 268 | {"recolor_cr": False}), 269 | (image(cr_center=(85, 25), pupil_center=(70, 100)), 270 | None, 271 | None, 272 | {"n_rays": 20, "cr_threshold_factor": 0, "cr_threshold_pixels": 5, 273 | "pupil_threshold_factor": 0, "pupil_threshold_pixels": 5, 274 | "index_length": 100}, 275 | {"iterations": 20, "threshold": 1, "minimum_points_for_fit": 10, 276 | "number_of_close_points": 3}, 277 | None, 278 | None, 279 | {"recolor_cr": False, "adaptive_pupil": True, "clip_pupil_values": True}) 280 | ]) 281 | def test_process_image(image, input_stream, output_stream, 282 | starburst_params, ransac_params, pupil_bounding_box, 283 | cr_bounding_box, kwargs): 284 | tracker = et.EyeTracker(input_stream, output_stream, 285 | starburst_params, ransac_params, 286 | pupil_bounding_box, cr_bounding_box, **kwargs) 287 | with patch.object(tracker, "update_last_pupil_color") as mock_update: 288 | cr, pupil, cr_err, pupil_err = tracker.process_image(image) 289 | if not kwargs.get("adaptive_pupil", True): 290 | assert mock_update.call_count == 0 291 | 292 | 293 | @pytest.mark.parametrize(("input_stream,output_stream," 294 | "starburst_params,ransac_params,pupil_bounding_box," 295 | "cr_bounding_box,generate_QC_output,kwargs"), [ 296 | (InputStream(), 297 | None, 298 | {"n_rays": 20, "cr_threshold_factor": 1.4, "cr_threshold_pixels": 5, 299 | "pupil_threshold_factor": 1.4, "pupil_threshold_pixels": 5, 300 | "index_length": 100}, 301 | {"iterations": 20, "threshold": 1, "minimum_points_for_fit": 10, 302 | "number_of_close_points": 3}, 303 | None, 304 | None, 305 | False, 306 | {}), 307 | (InputStream(), 308 | OutputStream((200, 200)), 309 | {"n_rays": 20, "cr_threshold_factor": 1.4, "cr_threshold_pixels": 5, 310 | "pupil_threshold_factor": 1.4, "pupil_threshold_pixels": 5, 311 | "index_length": 100}, 312 | {"iterations": 20, "threshold": 1, "minimum_points_for_fit": 10, 313 | "number_of_close_points": 3}, 314 | None, 315 | None, 316 | True, 317 | {}), 318 | ]) 319 | def test_process_stream(input_stream, output_stream, starburst_params, 320 | ransac_params, pupil_bounding_box, 321 | cr_bounding_box, generate_QC_output, kwargs): 322 | tracker = et.EyeTracker(input_stream, output_stream, 323 | starburst_params, ransac_params, 324 | pupil_bounding_box, cr_bounding_box, 325 | generate_QC_output, **kwargs) 326 | cr, pupil, cr_err, pupil_err = tracker.process_stream(start=3) 327 | assert(pupil.shape == (input_stream.num_frames - 3, 5)) 328 | cr, pupil, cr_err, pupil_err = tracker.process_stream( 329 | update_mean_frame=False) 330 | assert(pupil.shape == (input_stream.num_frames, 5)) 331 | tracker.input_stream = InputStream(0) 332 | with patch.object(tracker, "process_image") as mock_process: 333 | cr, pupil, cr_err, pupil_err = tracker.process_stream() 334 | assert mock_process.call_count == 0 335 | 336 | 337 | @pytest.mark.parametrize(("input_stream,output_stream," 338 | "starburst_params,ransac_params,pupil_bounding_box," 339 | "cr_bounding_box,generate_QC_output,kwargs"), [ 340 | (InputStream(), 341 | None, 342 | {"n_rays": 20, "cr_threshold_factor": 1.4, "cr_threshold_pixels": 5, 343 | "pupil_threshold_factor": 1.4, "pupil_threshold_pixels": 5, 344 | "index_length": 100}, 345 | {"iterations": 20, "threshold": 1, "minimum_points_for_fit": 10, 346 | "number_of_close_points": 3}, 347 | None, 348 | None, 349 | False, 350 | {}), 351 | ]) 352 | def test_mean_frame(input_stream, output_stream, starburst_params, 353 | ransac_params, pupil_bounding_box, 354 | cr_bounding_box, generate_QC_output, kwargs): 355 | tracker = et.EyeTracker(input_stream, output_stream, 356 | starburst_params, ransac_params, 357 | pupil_bounding_box, cr_bounding_box, 358 | generate_QC_output, **kwargs) 359 | assert(tracker.mean_frame.shape == input_stream.frame_shape) 360 | -------------------------------------------------------------------------------- /test/test_feature_extraction.py: -------------------------------------------------------------------------------- 1 | from allensdk.eye_tracking import feature_extraction 2 | import numpy as np 3 | from skimage.draw import circle 4 | import pytest 5 | 6 | 7 | @pytest.fixture 8 | def image(): 9 | image = np.zeros((100, 100)) 10 | image[circle(30, 30, 10, (100, 100))] = 1 11 | image[circle(60, 60, 10, (100, 100))] = 1 12 | 13 | return image 14 | 15 | 16 | @pytest.mark.parametrize("radius", [ 17 | 5, 18 | 10 19 | ]) 20 | def test_get_circle_template(radius): 21 | mask = feature_extraction.get_circle_template(radius) 22 | assert(mask.shape == (2*radius+7, 2*radius+7)) 23 | mask = feature_extraction.get_circle_template(radius) 24 | assert(mask.shape == (2*radius+7, 2*radius+7)) 25 | 26 | 27 | @pytest.mark.parametrize("image,bounding_box", [ 28 | (image(), None), 29 | (image(), (10, 45, 10, 45)), 30 | (image(), (45, 75, 45, 75)) 31 | ]) 32 | def test_max_correlation_positions(image, bounding_box): 33 | kernel = feature_extraction.get_circle_template(8) 34 | y, x = feature_extraction.max_correlation_positions(image, kernel, 35 | bounding_box) 36 | if bounding_box is not None: 37 | assert(x > bounding_box[0] and x < bounding_box[1]) 38 | assert(y > bounding_box[2] and y < bounding_box[3]) 39 | -------------------------------------------------------------------------------- /test/test_fit_ellipse.py: -------------------------------------------------------------------------------- 1 | from allensdk.eye_tracking import fit_ellipse as fe 2 | import numpy as np 3 | import pytest 4 | from mock import patch 5 | 6 | 7 | def rotate_vector(y, x, theta): 8 | 9 | xp = x*np.cos(theta) - y*np.sin(theta) 10 | yp = x*np.sin(theta) + y*np.cos(theta) 11 | 12 | return yp, xp 13 | 14 | 15 | def ellipse_points(a, b, x0, y0, rotation): 16 | x = np.linspace(-a, a, 200) 17 | yp = np.sqrt(b**2 - (b**2 / a**2)*x**2) 18 | ym = -yp 19 | yp, x1 = rotate_vector(yp, x, rotation) 20 | ym, x2 = rotate_vector(ym, x, rotation) 21 | x = np.hstack((x1, x2)) + x0 22 | y = np.hstack((yp, ym)) + y0 23 | return np.vstack((y, x)).T 24 | 25 | 26 | @pytest.mark.parametrize("a,b,x0,y0,rotation", [ 27 | (3.0, 2.0, 20, 30, np.pi/6) 28 | ]) 29 | def test_ellipse_fit(a, b, x0, y0, rotation): 30 | data = ellipse_points(a, b, x0, y0, rotation) 31 | fitter = fe.EllipseFitter(40, 40, 0.01, 50) 32 | p, err = fitter.fit(data) 33 | x, y, angle, ax1, ax2 = p 34 | assert(np.abs(x - x0) < 0.0001) 35 | assert(np.abs(y - y0) < 0.0001) 36 | assert(np.abs(angle - np.degrees(rotation)) < 0.01) 37 | assert((np.abs(ax1-a) < 0.0001 and np.abs(ax2-b) < 0.0001) or 38 | (np.abs(ax1-b) < 0.0001 and np.abs(ax2-a) < 0.0001)) 39 | with patch.object(fitter._fitter, "fit", return_value=(None, None)): 40 | res, err = fitter.fit(data) 41 | assert(np.all(np.isnan(res))) 42 | results, err = fitter.fit(data, max_radius=min(a, b)) 43 | assert(np.all(np.isnan(results))) 44 | results, err = fitter.fit(data, max_eccentricity=-1) 45 | assert(np.all(np.isnan(results))) 46 | 47 | 48 | @pytest.mark.parametrize("point,ellipse_params,tolerance,result", [ 49 | ((0, 30.0), (0, 0, 0, 30.0, 5.0), 0.01, False), 50 | ((0, 30.0), (0, 0, 0, 30.5, 5.0), 0.01, True), 51 | ((7.0, 7.0), (5, 5, 45, np.sqrt(8.0), 9), 0.01, False), 52 | ((7.0, 7.0), (5, 5, 45, 9, 9), 0.01, True) 53 | ]) 54 | def test_not_on_ellipse(point, ellipse_params, tolerance, result): 55 | assert(fe.not_on_ellipse(point, ellipse_params, tolerance) == result) 56 | -------------------------------------------------------------------------------- /test/test_frame_stream.py: -------------------------------------------------------------------------------- 1 | from allensdk.eye_tracking import frame_stream as fs 2 | import sys 3 | import numpy as np 4 | import mock 5 | import pytest 6 | 7 | DEFAULT_FRAMES = 101 8 | DEFAULT_CV_FRAMES = 20 9 | # H264 is not available by default on windows 10 | if sys.platform == "win32": 11 | FOURCC = "FMP4" 12 | else: 13 | FOURCC = "H264" 14 | 15 | 16 | def image(shape=(200, 200), value=0): 17 | return np.ones(shape, dtype=np.uint8)*value*10 18 | 19 | 20 | @pytest.fixture(scope="module") 21 | def movie(tmpdir_factory): 22 | frame = image() 23 | filename = str(tmpdir_factory.mktemp("test").join('movie.avi')) 24 | ostream = fs.CvOutputStream(filename, frame.shape[::-1], is_color=False, 25 | fourcc=FOURCC) 26 | ostream.open(filename) 27 | for i in range(DEFAULT_CV_FRAMES): 28 | ostream.write(image(value=i)) 29 | ostream.close() 30 | return filename 31 | 32 | 33 | @pytest.fixture() 34 | def outfile(tmpdir_factory): 35 | return str(tmpdir_factory.mktemp("test").join("output.avi")) 36 | 37 | 38 | def test_frame_input_init(): 39 | istream = fs.FrameInputStream("test_path") 40 | assert(istream.movie_path == "test_path") 41 | assert(istream.num_frames == 0) 42 | with pytest.raises(NotImplementedError): 43 | istream.frame_shape 44 | 45 | 46 | def test_frame_input_slice_errors(): 47 | mock_cb = mock.MagicMock() 48 | istream = fs.FrameInputStream("test_path", num_frames=DEFAULT_FRAMES, 49 | process_frame_cb=mock_cb) 50 | with pytest.raises(NotImplementedError): 51 | istream._get_frame(20) 52 | with pytest.raises(NotImplementedError): 53 | istream._seek_frame(20) 54 | with pytest.raises(ValueError): 55 | for x in istream[6:10:0]: 56 | pass 57 | with pytest.raises(KeyError): 58 | istream["invalid"] 59 | with pytest.raises(IndexError): 60 | istream[DEFAULT_FRAMES] 61 | with pytest.raises(IndexError): 62 | istream[-DEFAULT_FRAMES-1] 63 | 64 | 65 | @pytest.mark.parametrize("start,stop,step", [ 66 | (5, 10, None), 67 | (2, 30, 2), 68 | (30, 2, -2), 69 | (None, -1, None), 70 | (3, None, None) 71 | ]) 72 | def test_frame_input_slice(start, stop, step): 73 | mock_cb = mock.MagicMock() 74 | istream = fs.FrameInputStream("test_path", num_frames=DEFAULT_FRAMES, 75 | process_frame_cb=mock_cb) 76 | with mock.patch.object(istream, "_get_frame", new=lambda a: a): 77 | count = 0 78 | for x in istream: 79 | count += 1 80 | assert count == DEFAULT_FRAMES 81 | assert mock_cb.call_count == DEFAULT_FRAMES 82 | with mock.patch.object(istream, "_seek_frame", new=lambda b: b): 83 | mock_cb.reset_mock() 84 | x = istream[5] 85 | mock_cb.assert_called_once_with(5) 86 | mock_cb.reset_mock() 87 | x = istream[-20] 88 | mock_cb.assert_called_once_with(DEFAULT_FRAMES-20) 89 | with mock.patch.object(istream, "_seek_frame", new=lambda b: b): 90 | mock_cb.reset_mock() 91 | for x in istream[start:stop:step]: 92 | pass 93 | rstop = stop if stop is not None else DEFAULT_FRAMES 94 | rstart = start if start is not None else 0 95 | rstep = step if step is not None else 1 96 | expected = [mock.call(x) for x in range(rstart, rstop, rstep)] 97 | mock_cb.assert_has_calls(expected) 98 | 99 | 100 | def test_frame_input_context_manager(): 101 | with mock.patch.object(fs.traceback, "print_tb") as mock_tb: 102 | with pytest.raises(OSError): 103 | with fs.FrameInputStream("test_path") as istream: 104 | assert(istream.movie_path == "test_path") 105 | raise OSError() 106 | mock_tb.assert_called_once() 107 | with fs.FrameInputStream("test_path") as istream: 108 | istream._num_frames = 10 109 | istream.frames_read = 10 110 | 111 | 112 | def test_frame_input_close(): 113 | with mock.patch.object(fs.logging, "debug") as mock_debug: 114 | istream = fs.FrameInputStream("test_path") 115 | istream.close() 116 | istream = fs.FrameInputStream("test_path", num_frames=10) 117 | istream.close() 118 | assert(mock_debug.call_count == 2) 119 | 120 | 121 | def test_cv_input_num_frames(movie): 122 | istream = fs.CvInputStream(movie) 123 | assert(istream.num_frames == DEFAULT_CV_FRAMES) 124 | assert(istream.num_frames == DEFAULT_CV_FRAMES) # using cached value 125 | 126 | 127 | def test_cv_input_frame_shape(movie): 128 | istream = fs.CvInputStream(movie) 129 | assert(istream.frame_shape == (200, 200)) 130 | assert(istream.frame_shape == (200, 200)) # using cached value 131 | 132 | 133 | def test_cv_input_open(movie): 134 | istream = fs.CvInputStream(movie) 135 | istream.open() 136 | with pytest.raises(IOError): 137 | istream.open() 138 | istream._error() 139 | assert(istream.cap is None) 140 | 141 | 142 | def test_cv_input_close(movie): 143 | istream = fs.CvInputStream(movie) 144 | istream.close() 145 | 146 | 147 | def test_cv_input_ioerrors(movie): 148 | istream = fs.CvInputStream(movie) 149 | with pytest.raises(IOError): 150 | istream._seek_frame(10) 151 | with pytest.raises(IOError): 152 | istream._get_frame(10) 153 | 154 | 155 | def test_cv_input_iter(movie): 156 | mock_cb = mock.MagicMock() 157 | istream = fs.CvInputStream(movie, process_frame_cb=mock_cb) 158 | count = 0 159 | for x in istream: 160 | count += 1 161 | assert count == DEFAULT_CV_FRAMES 162 | assert mock_cb.call_count == DEFAULT_CV_FRAMES 163 | mock_cb.reset_mock() 164 | for x in istream[5:10]: 165 | pass 166 | for i, x in enumerate(range(5, 10)): 167 | assert(np.all(np.abs(image(value=x) - 168 | mock_cb.mock_calls[i][1][0][:, :, 0]) < 2)) 169 | mock_cb.reset_mock() 170 | for x in istream[2:18:2]: 171 | pass 172 | for i, x in enumerate(range(2, 18, 2)): 173 | assert(np.all(np.abs(image(value=x) - 174 | mock_cb.mock_calls[i][1][0][:, :, 0]) < 2)) 175 | mock_cb.reset_mock() 176 | for x in istream[18:2:-2]: 177 | pass 178 | for i, x in enumerate(range(18, 2, -2)): 179 | assert(np.all(np.abs(image(value=x) - 180 | mock_cb.mock_calls[i][1][0][:, :, 0]) < 2)) 181 | 182 | 183 | def test_frame_output_init(): 184 | ostream = fs.FrameOutputStream(200) 185 | assert(ostream.frames_processed == 0) 186 | assert(ostream.block_size == 200) 187 | 188 | 189 | def test_frame_output_open(): 190 | ostream = fs.FrameOutputStream() 191 | ostream.frames_processed = 1 192 | ostream.open("test") 193 | assert(ostream.frames_processed == 0) 194 | assert(ostream.movie_path == "test") 195 | 196 | 197 | def test_frame_output_close(): 198 | with mock.patch.object(fs.FrameOutputStream, "_write_frames") as write: 199 | ostream = fs.FrameOutputStream() 200 | ostream.block_frames = [1, 2] 201 | ostream.close() 202 | write.assert_called_once_with([1, 2]) 203 | 204 | 205 | def test_frame_output_context_manager(): 206 | with mock.patch.object(fs.FrameOutputStream, "close") as mock_close: 207 | with pytest.raises(OSError): 208 | with fs.FrameOutputStream() as ostream: 209 | raise OSError() 210 | with fs.FrameOutputStream() as ostream: # noqa: F841 211 | pass 212 | mock_close.assert_called_once() 213 | 214 | 215 | def test_frame_output_write(): 216 | with pytest.raises(NotImplementedError): 217 | ostream = fs.FrameOutputStream() 218 | ostream.write(1) 219 | with mock.patch.object(fs.FrameOutputStream, "_write_frames") as write: 220 | ostream = fs.FrameOutputStream() 221 | ostream.write(1) 222 | write.assert_called_once() 223 | with mock.patch.object(fs.FrameOutputStream, "_write_frames") as write: 224 | ostream = fs.FrameOutputStream(block_size=50) 225 | ostream.write(1) 226 | ostream.close() 227 | write.assert_called_once() 228 | 229 | 230 | def test_cv_output_open(outfile): 231 | ostream = fs.CvOutputStream(outfile, (200, 200), fourcc=FOURCC) 232 | ostream.open(outfile) 233 | assert(ostream.movie_path == outfile) 234 | with pytest.raises(IOError): 235 | ostream.open(outfile) 236 | 237 | 238 | def test_cv_output_context_manager(outfile): 239 | with pytest.raises(IOError): 240 | with fs.CvOutputStream(outfile, (200, 200), fourcc=FOURCC) as ostream: 241 | pass 242 | with pytest.raises(IOError): 243 | with fs.CvOutputStream(outfile, (200, 200), fourcc=FOURCC) as ostream: 244 | ostream.open(outfile) 245 | ostream.open(outfile) 246 | 247 | 248 | def test_cv_output_write(outfile): 249 | ostream = fs.CvOutputStream(outfile, (200, 200), is_color=False, 250 | fourcc=FOURCC) 251 | ostream.write(image()) 252 | ostream.write(image()) 253 | ostream.close() 254 | check = fs.CvInputStream(outfile) 255 | assert(check.num_frames == 2) 256 | -------------------------------------------------------------------------------- /test/test_module.py: -------------------------------------------------------------------------------- 1 | from allensdk.eye_tracking import __main__ 2 | from allensdk.eye_tracking.frame_stream import CvOutputStream, CvInputStream 3 | import mock 4 | import numpy as np 5 | import ast 6 | import os 7 | import sys 8 | import json 9 | from skimage.draw import circle 10 | import pytest 11 | 12 | # H264 is not available by default on windows 13 | if sys.platform == "win32": 14 | FOURCC = "FMP4" 15 | else: 16 | FOURCC = "H264" 17 | 18 | 19 | def image(shape=(200, 200), cr_radius=10, cr_center=(100, 100), 20 | pupil_radius=30, pupil_center=(100, 100)): 21 | im = np.ones(shape, dtype=np.uint8)*128 22 | r, c = circle(pupil_center[0], pupil_center[1], pupil_radius, shape) 23 | im[r, c] = 0 24 | r, c = circle(cr_center[0], cr_center[1], cr_radius, shape) 25 | im[r, c] = 255 26 | return im 27 | 28 | 29 | def input_stream(source): 30 | mock_istream = mock.MagicMock() 31 | mock_istream.num_frames = 2 32 | mock_istream.frame_shape = (200, 200) 33 | mock_istream.__iter__ = mock.MagicMock( 34 | return_value=iter([np.zeros((200, 200)), np.zeros((200, 200))])) 35 | return mock_istream 36 | 37 | 38 | @pytest.fixture() 39 | def input_source(tmpdir_factory): 40 | filename = str(tmpdir_factory.mktemp("test").join('input.avi')) 41 | frame = image() 42 | ostream = CvOutputStream(filename, frame.shape[::-1], is_color=False, 43 | fourcc=FOURCC) 44 | ostream.open(filename) 45 | for i in range(10): 46 | ostream.write(frame) 47 | ostream.close() 48 | return filename 49 | 50 | 51 | @pytest.fixture() 52 | def input_json(tmpdir_factory): 53 | filename = str(tmpdir_factory.mktemp("test").join('input.json')) 54 | output_dir = str(tmpdir_factory.mktemp("test")) 55 | annotation_file = str(tmpdir_factory.mktemp("test").join('anno.avi')) 56 | in_json = {"starburst": {}, 57 | "ransac": {}, 58 | "eye_params": {}, 59 | "qc": { 60 | "generate_plots": False, 61 | "output_dir": output_dir}, 62 | "annotation": {"annotate_movie": False, 63 | "output_file": annotation_file, 64 | "fourcc": FOURCC}, 65 | "cr_bounding_box": [], 66 | "pupil_bounding_box": [], 67 | "output_dir": output_dir} 68 | with open(filename, "w") as f: 69 | json.dump(in_json, f, indent=1) 70 | return str(filename) 71 | 72 | 73 | def validate_dict(reference_dict, compare_dict): 74 | for k, v in reference_dict.items(): 75 | if isinstance(v, dict): 76 | validate_dict(v, compare_dict[k]) 77 | else: 78 | assert(compare_dict[k] == v) 79 | 80 | 81 | def assert_output(output_dir, annotation_file=None, qc_output_dir=None, 82 | output_json=None, input_data=None): 83 | cr = np.load(os.path.join(output_dir, "cr_params.npy")) 84 | pupil = np.load(os.path.join(output_dir, "pupil_params.npy")) 85 | assert(os.path.exists(os.path.join(output_dir, "mean_frame.png"))) 86 | assert(cr.shape == (10, 5)) 87 | assert(pupil.shape == (10, 5)) 88 | if annotation_file: 89 | check = CvInputStream(annotation_file) 90 | assert(check.num_frames == 10) 91 | check.close() 92 | if output_json: 93 | assert(os.path.exists(output_json)) 94 | if input_data: 95 | with open(output_json, "r") as f: 96 | output_data = json.load(f) 97 | validate_dict(input_data, output_data["input_parameters"]) 98 | if qc_output_dir: 99 | assert(os.path.exists(os.path.join(output_dir, "cr_all.png"))) 100 | 101 | 102 | def test_main_valid(input_source, input_json, tmpdir_factory): 103 | output_dir = str(tmpdir_factory.mktemp("output")) 104 | args = ["allensdk.eye_tracking", "--output_dir", output_dir, 105 | "--input_source", input_source] 106 | with mock.patch('sys.argv', args): 107 | __main__.main() 108 | assert_output(output_dir) 109 | 110 | 111 | @pytest.mark.parametrize("pupil_bbox_str,cr_bbox_str, adaptive_pupil", [ 112 | ("[20,50,40,70]", "[40,70,20,50]", True), 113 | ("[]", "[]", False) 114 | ]) 115 | def test_main_valid_json(input_source, input_json, pupil_bbox_str, cr_bbox_str, 116 | adaptive_pupil): 117 | args = ["allensdk.eye_tracking", "--input_json", input_json, 118 | "--input_source", input_source] 119 | with open(input_json, "r") as f: 120 | json_data = json.load(f) 121 | output_dir = json_data["output_dir"] 122 | with mock.patch('sys.argv', args): 123 | __main__.main() 124 | assert_output(output_dir) 125 | out_json = os.path.join(output_dir, "output.json") 126 | args.extend(["--qc.generate_plots", "True", 127 | "--annotation.annotate_movie", "True", 128 | "--output_json", out_json, 129 | "--pupil_bounding_box", pupil_bbox_str, 130 | "--cr_bounding_box", cr_bbox_str, 131 | "--eye_params.adaptive_pupil", str(adaptive_pupil)]) 132 | with mock.patch('sys.argv', args): 133 | __main__.main() 134 | json_data["eye_params"]["adaptive_pupil"] = adaptive_pupil 135 | json_data["qc"]["generate_plots"] = True 136 | json_data["annotation"]["annotate_movie"] = True 137 | json_data["output_json"] = out_json 138 | json_data["pupil_bounding_box"] = ast.literal_eval(pupil_bbox_str) 139 | json_data["cr_bounding_box"] = ast.literal_eval(cr_bbox_str) 140 | assert_output(output_dir, 141 | annotation_file=json_data["annotation"]["output_file"], 142 | qc_output_dir=json_data["qc"]["output_dir"], 143 | output_json=out_json, 144 | input_data=json_data) 145 | __main__.plt.close("all") 146 | 147 | 148 | def test_starburst_override(input_source, input_json): 149 | args = ["allensdk.eye_tracking", "--input_json", input_json, 150 | "--input_source", input_source] 151 | with open(input_json, "r") as f: 152 | json_data = json.load(f) 153 | output_dir = json_data["output_dir"] 154 | out_json = os.path.join(output_dir, "output.json") 155 | args.extend(["--starburst.cr_threshold_factor", "1.8", 156 | "--starburst.pupil_threshold_factor", "2.0", 157 | "--starburst.cr_threshold_pixels", "5", 158 | "--starburst.pupil_threshold_pixels", "30", 159 | "--output_json", out_json]) 160 | with mock.patch('sys.argv', args): 161 | __main__.main() 162 | json_data["starburst"]["cr_threshold_factor"] = 1.8 163 | json_data["starburst"]["pupil_threshold_factor"] = 2.0 164 | json_data["starburst"]["cr_threshold_pixels"] = 5 165 | json_data["starburst"]["pupil_threshold_pixels"] = 30 166 | json_data["output_json"] = out_json 167 | assert_output(output_dir, 168 | output_json=out_json, 169 | input_data=json_data) 170 | 171 | 172 | def test_main_invalid(): 173 | with mock.patch("sys.argv", ["allensdk.eye_tracking"]): 174 | with mock.patch("argparse.ArgumentParser.print_usage") as mock_print: 175 | __main__.main() 176 | mock_print.assert_called_once() 177 | -------------------------------------------------------------------------------- /test/test_plotting.py: -------------------------------------------------------------------------------- 1 | from allensdk.eye_tracking import plotting 2 | import mock 3 | import numpy as np 4 | import pytest 5 | 6 | 7 | class WriteEvaluator(object): 8 | def __init__(self, shape): 9 | self.shape = (shape[0], shape[1], 3) 10 | self.closed = False 11 | 12 | def write(self, array): 13 | assert(array.shape == self.shape) 14 | 15 | def close(self): 16 | self.closed = True 17 | 18 | 19 | def frame(height, width): 20 | return np.zeros((height, width), dtype=np.uint8) 21 | 22 | 23 | def test_get_rgb_frame(): 24 | img = frame(100, 100) 25 | with pytest.raises(ValueError): 26 | plotting.get_rgb_frame(np.dstack((img, img))) 27 | f = plotting.get_rgb_frame(img) 28 | assert f.shape == (100, 100, 3) 29 | f = plotting.get_rgb_frame(np.dstack((img, img, img))) 30 | assert f.shape == (100, 100, 3) 31 | 32 | 33 | @pytest.mark.parametrize("frame,pupil_params,cr_params", [ 34 | (frame(100, 100), 35 | np.array((40, 50, 45, 10, 8)), 36 | np.array((30, 60, 0, 5, 4))), 37 | (frame(100, 100), 38 | np.array((np.nan, np.nan, np.nan, np.nan, np.nan)), 39 | np.array((30, 60, 0, 5, 4))), 40 | (frame(100, 100), 41 | np.array((40, 50, 45, 10, 8)), 42 | np.array((np.nan, np.nan, np.nan, np.nan, np.nan))), 43 | ]) 44 | def test_annotate_frame(frame, pupil_params, cr_params): 45 | ostream = WriteEvaluator(frame.shape) 46 | annotator = plotting.Annotator(ostream) 47 | annotator.annotate_frame(frame, pupil_params, cr_params) 48 | annotator.close() 49 | assert(ostream.closed) 50 | 51 | annotator = plotting.Annotator() 52 | with mock.patch.object(annotator, "update_rc", 53 | mock.MagicMock(return_value=False)): 54 | with mock.patch("allensdk.eye_tracking.plotting." 55 | "color_by_points") as mock_color: 56 | annotator.annotate_frame(frame, pupil_params, cr_params) 57 | assert mock_color.call_count == 0 58 | annotator.annotate_frame(frame, pupil_params, cr_params) 59 | annotator.close() 60 | 61 | 62 | @pytest.mark.parametrize("frame,pupil_params,cr_params", [ 63 | (frame(100, 100), 64 | np.array((40, 50, 45, 10, 8)), 65 | np.array((30, 60, 0, 5, 4))), 66 | (frame(100, 100), 67 | np.array((np.nan, np.nan, np.nan, np.nan, np.nan)), 68 | np.array((30, 60, 0, 5, 4))), 69 | ]) 70 | def test_compute_density(frame, pupil_params, cr_params): 71 | ostream = WriteEvaluator(frame.shape) 72 | annotator = plotting.Annotator(ostream) 73 | assert(annotator._r["pupil"] is None) 74 | assert(annotator._c["pupil"] is None) 75 | assert(annotator._r["cr"] is None) 76 | assert(annotator._c["cr"] is None) 77 | assert(annotator.densities["pupil"] is None) 78 | assert(annotator.densities["cr"] is None) 79 | annotator.compute_density(frame, pupil_params, cr_params) 80 | if np.any(np.isnan(pupil_params)): 81 | assert(np.all(annotator.densities["pupil"] == 0)) 82 | else: 83 | assert(annotator.densities["pupil"].shape == frame.shape) 84 | 85 | 86 | @pytest.mark.parametrize("frame,pupil_params,cr_params", [ 87 | (frame(100, 100), 88 | np.array((40, 50, 45, 10, 8)), 89 | np.array((30, 60, 0, 5, 4))), 90 | (frame(100, 100), 91 | np.array((np.nan, np.nan, np.nan, np.nan, np.nan)), 92 | np.array((30, 60, 0, 5, 4))), 93 | ]) 94 | def test_annotate_with_cumulative(frame, pupil_params, cr_params): 95 | ostream = WriteEvaluator(frame.shape) 96 | annotator = plotting.Annotator(ostream) 97 | annotator.compute_density(frame, pupil_params, cr_params) 98 | 99 | with mock.patch.object(plotting.plt, "imsave") as mock_imsave: 100 | res = annotator.annotate_with_cumulative_pupil(frame, "pupil.png") 101 | mock_imsave.assert_called_with("pupil.png", mock.ANY) 102 | assert(res.shape == (frame.shape[0], frame.shape[1], 3)) 103 | res = annotator.annotate_with_cumulative_cr(frame, "cr.png") 104 | mock_imsave.assert_called_with("cr.png", mock.ANY) 105 | assert(res.shape == (frame.shape[0], frame.shape[1], 3)) 106 | 107 | with mock.patch("allensdk.eye_tracking.plotting." 108 | "color_by_mask") as mock_color: 109 | with mock.patch.object(plotting.plt, "imsave") as mock_imsave: 110 | res = plotting.annotate_with_cumulative(frame, None) 111 | assert(mock_color.call_count == 0) 112 | assert(mock_imsave.call_count == 0) 113 | 114 | 115 | def test_annotate_with_box(): 116 | img = frame(100, 100) 117 | bbox = np.array((20, 50, 40, 80), dtype="int") 118 | with mock.patch.object(plotting.plt, "imsave") as mock_imsave: 119 | res = plotting.annotate_with_box(img, bbox) 120 | assert(res.shape == (100, 100, 3)) 121 | assert(mock_imsave.call_count == 0) 122 | res = plotting.annotate_with_box(img, bbox, filename="test") 123 | assert(res.shape == (100, 100, 3)) 124 | mock_imsave.assert_called_once_with("test", res) 125 | 126 | 127 | @pytest.mark.parametrize("frame,pupil_params,cr_params,output_dir", [ 128 | (frame(100, 100), 129 | np.array((40, 50, 45, 10, 8)), 130 | np.array((30, 60, 0, 5, 4)), 131 | None), 132 | (frame(100, 100), 133 | np.array((np.nan, np.nan, np.nan, np.nan, np.nan)), 134 | np.array((30, 60, 0, 5, 4)), 135 | "test"), 136 | ]) 137 | def test_plot_cumulative(frame, pupil_params, cr_params, output_dir): 138 | ostream = WriteEvaluator(frame.shape) 139 | annotator = plotting.Annotator(ostream) 140 | annotator.compute_density(frame, pupil_params, cr_params) 141 | with mock.patch.object(plotting.plt, "show") as mock_show: 142 | with mock.patch.object(plotting.plt.Figure, "savefig") as mock_savefig: 143 | plotting.plot_cumulative(annotator.densities["pupil"], 144 | annotator.densities["cr"], 145 | output_dir=output_dir, 146 | show=False) 147 | plotting.plot_cumulative(annotator.densities["pupil"], 148 | annotator.densities["cr"], 149 | output_dir=output_dir, 150 | show=True) 151 | mock_show.assert_called_once() 152 | if output_dir is not None: 153 | assert(mock_savefig.call_count == 4) 154 | else: 155 | assert(mock_savefig.call_count == 0) 156 | plotting.plt.close("all") 157 | 158 | 159 | @pytest.mark.parametrize("pupil_params,cr_params,output_folder", [ 160 | (np.array((40, 50, 45, 10, 8)), 161 | np.array((30, 60, 0, 5, 4)), 162 | None), 163 | (np.array((40, 50, 45, 10, 8)), 164 | np.array((30, 60, 0, 5, 4)), 165 | "test"), 166 | ]) 167 | def test_plot_summary(pupil_params, cr_params, output_folder): 168 | with mock.patch.object(plotting.plt, "show") as mock_show: 169 | with mock.patch.object(plotting.plt.Figure, "savefig") as mock_savefig: 170 | plotting.plot_summary(pupil_params, cr_params, 171 | output_folder, False) 172 | plotting.plot_summary(pupil_params, cr_params, output_folder, True) 173 | mock_show.assert_called_once() 174 | if output_folder is not None: 175 | assert(mock_savefig.call_count == 12) 176 | else: 177 | assert(mock_savefig.call_count == 0) 178 | plotting.plt.close("all") 179 | 180 | 181 | def test_plots_no_title(): 182 | data = np.arange(50) 183 | img = frame(100, 100) 184 | with mock.patch.object(plotting.plt.Axes, "set_title") as mock_set: 185 | plotting.plot_timeseries(data, None) 186 | plotting.plot_density(img) 187 | assert mock_set.call_count == 0 188 | -------------------------------------------------------------------------------- /test/test_qt_ui.py: -------------------------------------------------------------------------------- 1 | from allensdk.eye_tracking.ui import qt, __main__ 2 | from allensdk.eye_tracking._schemas import InputParameters 3 | from allensdk.eye_tracking.frame_stream import CvOutputStream 4 | import os 5 | import sys 6 | import numpy as np 7 | import json 8 | import pytest # noqa: F401 9 | from mock import patch, MagicMock 10 | 11 | DEFAULT_CV_FRAMES = 20 12 | # H264 is not available by default on windows 13 | if sys.platform == "win32": 14 | FOURCC = "FMP4" 15 | else: 16 | FOURCC = "H264" 17 | 18 | 19 | def image(shape=(200, 200), value=0): 20 | return np.ones(shape, dtype=np.uint8)*value*10 21 | 22 | 23 | @pytest.fixture(scope="module", params=[True, False]) 24 | def movie(tmpdir_factory, request): 25 | if not request.param: 26 | return "" 27 | frame = image() 28 | filename = str(tmpdir_factory.mktemp("test").join('movie.avi')) 29 | ostream = CvOutputStream(filename, frame.shape[::-1], is_color=False, 30 | fourcc=FOURCC) 31 | ostream.open(filename) 32 | for i in range(DEFAULT_CV_FRAMES): 33 | ostream.write(image(value=i)) 34 | ostream.close() 35 | return filename 36 | 37 | 38 | @pytest.fixture(scope="module", params=[True, False]) 39 | def config_file(tmpdir_factory, request): 40 | if request.param: 41 | filename = str(tmpdir_factory.mktemp("test").join('config.json')) 42 | config = {"input_json": {"input_source": {"read_only": True}}} 43 | with open(filename, "w") as f: 44 | json.dump(config, f) 45 | else: 46 | filename = "" 47 | return filename 48 | 49 | 50 | @pytest.fixture(params=[True, False]) 51 | def json_file(tmpdir_factory, request): 52 | if request.param: 53 | return str(tmpdir_factory.mktemp("test").join("file.json")) 54 | else: 55 | return "" 56 | 57 | 58 | @pytest.fixture(params=[True, False]) 59 | def mock_file_event(movie, request): 60 | mock_event = MagicMock() 61 | mock_data = MagicMock() 62 | mock_urls = MagicMock(return_value=[movie]) 63 | mock_urls.toLocalFile = MagicMock(return_value=movie) 64 | mock_data.hasUrls = MagicMock(return_value=request.param) 65 | mock_data.urls = MagicMock(return_value=mock_urls) 66 | mock_event.mimeData = MagicMock(return_value=mock_data) 67 | return mock_event 68 | 69 | 70 | @patch.object(qt.DropFileMixin, "file_dropped") 71 | def test_drop_file_mixin(mock_signal, qtbot, mock_file_event): 72 | w = qt.DropFileMixin() 73 | w.dragEnterEvent(mock_file_event) 74 | if mock_file_event.mimeData().hasUrls(): 75 | mock_file_event.accept.assert_called_once() 76 | else: 77 | mock_file_event.ignore.assert_called_once() 78 | mock_file_event.reset_mock() 79 | w.dragMoveEvent(mock_file_event) 80 | if mock_file_event.mimeData().hasUrls(): 81 | mock_file_event.accept.assert_called_once() 82 | else: 83 | mock_file_event.ignore.assert_called_once() 84 | mock_file_event.reset_mock() 85 | w.dropEvent(mock_file_event) 86 | if mock_file_event.mimeData().hasUrls(): 87 | mock_file_event.accept.assert_called_once() 88 | else: 89 | mock_file_event.ignore.assert_called_once() 90 | mock_file_event.reset_mock() 91 | 92 | 93 | def test_field_widget(qtbot): 94 | schema = InputParameters() 95 | w = qt.FieldWidget("test", schema.fields["output_dir"]) 96 | assert(w.key == "test") 97 | assert(w.field == schema.fields["output_dir"]) 98 | js = w.get_json() 99 | default = schema.fields["output_dir"].default 100 | assert(os.path.normpath(js) == os.path.normpath(default)) 101 | 102 | w = qt.FieldWidget("test2", schema.fields["input_source"]) 103 | assert(w.key == "test2") 104 | assert(str(w.text()) == "") 105 | js = w.get_json() 106 | assert(js is None) 107 | 108 | w = qt.FieldWidget("test3", schema.fields["pupil_bounding_box"]) 109 | assert(w.key == "test3") 110 | default = schema.fields["pupil_bounding_box"].default 111 | assert(str(w.text()) == str(default)) 112 | js = w.get_json() 113 | assert(js == default) 114 | 115 | 116 | def test_schema_widget(qtbot): 117 | schema = InputParameters() 118 | w = qt.SchemaWidget(None, schema, None) 119 | assert(w.key is None) 120 | assert(isinstance(w.fields["eye_params"], qt.SchemaWidget)) 121 | assert(isinstance(w.fields["input_source"], qt.FieldWidget)) 122 | 123 | js = w.get_json() 124 | assert(js) 125 | with patch.object(qt.FieldWidget, "get_json", return_value=None): 126 | js = w.get_json() 127 | assert(js is None) 128 | 129 | w.update_value("eye_params.min_pupil_value", "1000") 130 | val = str(w.fields["eye_params"].fields["min_pupil_value"].text()) 131 | assert(val == "1000") 132 | 133 | 134 | def test_input_json_widget(qtbot): 135 | schema = InputParameters() 136 | w = qt.InputJsonWidget(schema) 137 | assert(isinstance(w.schema_widget, qt.SchemaWidget)) 138 | w.update_value("eye_params.min_pupil_value", "1000") 139 | js = w.get_json() 140 | assert(js["eye_params"]["min_pupil_value"] == 1000) 141 | 142 | 143 | def test_bbox_canvas(qtbot): 144 | w = qt.BBoxCanvas(qt.Figure()) 145 | w.set_rgb(0, 0, 0) 146 | assert(w.rgba == (0, 0, 0, 20)) 147 | assert(not w.drawing) 148 | mock_wheel = MagicMock() 149 | w.wheelEvent(mock_wheel) 150 | mock_wheel.ignore.assert_called_once() 151 | w.show() 152 | w.move(0, 0) 153 | qtbot.addWidget(w) 154 | qtbot.mousePress(w, qt.QtCore.Qt.LeftButton, 155 | qt.QtCore.Qt.NoModifier, 156 | qt.QtCore.QPoint(0, 0)) 157 | assert(w.drawing) 158 | qtbot.mouseMove(w, qt.QtCore.QPoint(30, 30), 100) 159 | assert(w.drawing) 160 | with qtbot.waitSignal(w.box_updated) as box_updated: 161 | qtbot.mouseRelease(w, qt.QtCore.Qt.LeftButton, 162 | qt.QtCore.Qt.NoModifier, 163 | qt.QtCore.QPoint(50, 50)) 164 | assert(not w.drawing) 165 | assert(box_updated.signal_triggered) 166 | 167 | 168 | @patch.object(qt.BBoxCanvas, "width", return_value=200) 169 | @patch.object(qt.BBoxCanvas, "height", return_value=100) 170 | def test_bbox_canvas_scaling(h, w, qtbot): 171 | w = qt.BBoxCanvas(qt.Figure()) 172 | assert(w.im_shape == (100, 200)) 173 | w.im_shape = (100, 100) 174 | assert(w.im_shape == (100, 100)) 175 | s, x, y = w._scale_and_offset() 176 | assert(y == 0) 177 | assert(np.allclose(s, 1.0)) 178 | w.im_shape = (100, 400) 179 | assert(w.im_shape == (100, 400)) 180 | s, x, y = w._scale_and_offset() 181 | assert(x == 0) 182 | assert(np.allclose(s, 2.0)) 183 | 184 | 185 | @pytest.mark.parametrize("profile,config", [ 186 | (True, {"input_json": {"cr_bounding_box": {"visible": False}}}), 187 | (False, None) 188 | ]) 189 | @patch.object(qt.QtWidgets.QMessageBox, "exec_") 190 | def test_viewer_window(mock_exec, qtbot, movie, json_file, profile, config): 191 | schema_type = InputParameters 192 | w = qt.ViewerWindow(schema_type, profile_runs=profile, config=config) 193 | w.widget.json_view.update_value("cr_bounding_box", "[") 194 | qtbot.mouseClick(w.widget.rerun_button, qt.QtCore.Qt.LeftButton) 195 | mock_exec.assert_called_once() 196 | mock_exec.reset_mock() 197 | w.widget.save_json() 198 | mock_exec.assert_called_once() 199 | w.widget.json_view.update_value("cr_bounding_box", "[]") 200 | w.widget._setup_bbox() 201 | w.widget.update_bbox(100, 200, 100, 200) 202 | qtbot.mouseClick(w.widget.pupil_radio, qt.QtCore.Qt.LeftButton) 203 | w.widget.update_bbox(10, 50, 10, 50) 204 | assert(w.widget.get_json_data()["pupil_bounding_box"] == [10, 50, 10, 50]) 205 | qtbot.mouseClick(w.widget.cr_radio, qt.QtCore.Qt.LeftButton) 206 | w.widget.update_bbox(10, 50, 10, 50) 207 | assert(w.widget.get_json_data()["cr_bounding_box"] == [10, 50, 10, 50]) 208 | with patch.object(qt.QtWidgets.QFileDialog, "getSaveFileName", 209 | return_value=(json_file, None)): 210 | mock_exec.reset_mock() 211 | w.widget.save_json() 212 | mock_exec.assert_called_once() 213 | with patch.object(qt.QtWidgets.QFileDialog, "getOpenFileName", 214 | return_value=(movie, None)): 215 | w.widget.load_video() 216 | if movie: 217 | w.widget.save_json() 218 | w.widget.update_tracker() 219 | with patch.object(w.widget, "_parse_args", return_value={}): 220 | with patch.object(w.widget.tracker, "update_fit_parameters") as update: 221 | w.widget.update_tracker() 222 | assert(update.call_count == 0) 223 | if movie: 224 | w.widget.load_video(movie) 225 | 226 | 227 | @patch("allensdk.eye_tracking.ui.qt.QtWidgets.QApplication") 228 | def test_main(mock_app, qtbot, config_file): 229 | mock_app.exec_ = MagicMock(return_value=0) 230 | args = ["allensdk.eye_tracking_ui"] 231 | if config_file: 232 | args.extend(["--config_file", config_file]) 233 | with patch("sys.argv", args): 234 | with pytest.raises(SystemExit): 235 | __main__.main() 236 | 237 | 238 | @patch("allensdk.eye_tracking.ui.qt.QtWidgets.QApplication") 239 | def test_main_invalid(mock_app, qtbot, movie): 240 | mock_app.exec_ = MagicMock(return_value=0) 241 | args = ["allensdk.eye_tracking_ui"] 242 | if movie: 243 | args.extend(["--config_file", movie]) 244 | with patch("sys.argv", args): 245 | with pytest.raises(SystemExit): 246 | __main__.main() 247 | -------------------------------------------------------------------------------- /test/test_ransac.py: -------------------------------------------------------------------------------- 1 | from allensdk.eye_tracking import ransac 2 | import numpy as np 3 | import pytest 4 | 5 | 6 | def parameters(a, b): 7 | return np.array([b, a]) 8 | 9 | 10 | def line_data(a, b): 11 | d, _ = np.meshgrid(np.arange(1, 51), np.arange(2)) 12 | d = d.astype(float) 13 | d[1, :] *= b 14 | d[1, :] += a 15 | return d.T 16 | 17 | 18 | def poly_data(a, b, c): 19 | d, _ = np.meshgrid(np.arange(1, 51), np.arange(2)) 20 | d = d.astype(float) 21 | d[1, :] = a + b*d[0, :] + c*(d[0, :])**2 22 | return d.T 23 | 24 | 25 | def error_function(params, data): 26 | res = params[1] + params[0]*data[:, 0] 27 | return (data[:, 1]-res)**2 28 | 29 | 30 | def fit_function(data): 31 | params = np.polyfit(data[:, 0], data[:, 1], 1) 32 | error = np.mean(error_function(params, data)) 33 | return params, error 34 | 35 | 36 | @pytest.mark.parametrize("offset,slope,data_offset,threshold", [ 37 | (0, 1.5, 0, 1), 38 | (20, 0.5, 10, 1) 39 | ]) 40 | def test_check_outliers(offset, slope, data_offset, threshold): 41 | params = parameters(offset, slope) 42 | data = line_data(data_offset, slope) 43 | outlier_inds = np.arange(5, dtype=np.uint8) 44 | also_ins = ransac.check_outliers(error_function, params, data, 45 | outlier_inds, threshold) 46 | if (data_offset - offset)**2 > threshold: 47 | assert(len(also_ins) == 0) 48 | else: 49 | assert(len(also_ins) == 5) 50 | 51 | 52 | def test_partition_candidate_indices(): 53 | data = line_data(1, 1) 54 | ins, outs = ransac.partition_candidate_indices(data, 25) 55 | assert(len(ins) == len(outs) == 25) 56 | 57 | 58 | def test_fit(): 59 | data = line_data(0, 1) 60 | rf = ransac.RansacFitter() 61 | model, err = rf.fit(fit_function, error_function, data, 1.5, 20, 10, 10) 62 | assert(np.abs(model[0] - 1) < 0.0000001) 63 | assert(np.abs(model[1] - 0) < 0.0000001) 64 | with pytest.raises(ValueError): 65 | rf.fit(fit_function, error_function, data, 1.5, 200, 10, 10) 66 | data = poly_data(1, 2, 3) 67 | model, err = rf.fit(fit_function, error_function, data, 1.5, 20, 10, 10) 68 | assert(model is None) 69 | -------------------------------------------------------------------------------- /test/test_utils.py: -------------------------------------------------------------------------------- 1 | from allensdk.eye_tracking import utils 2 | import numpy as np 3 | import pytest 4 | 5 | 6 | @pytest.fixture 7 | def image(): 8 | im, _ = np.meshgrid(np.arange(400), np.arange(400)) 9 | return im 10 | 11 | 12 | @pytest.mark.parametrize("index_length,n_rays", [ 13 | (200, 30), 14 | (100, 10) 15 | ]) 16 | def test_rotate_rays(index_length, n_rays): 17 | x = np.arange(index_length).reshape(1, index_length) 18 | y = np.zeros((1, index_length)) 19 | a = (np.arange(n_rays)*2.0*np.pi/n_rays).reshape(n_rays, 1) 20 | xr, yr = utils.rotate_rays(x, y, a) 21 | assert(xr.shape == yr.shape == (n_rays, index_length)) 22 | with pytest.raises(ValueError): 23 | a = a.reshape(1, n_rays) 24 | utils.rotate_rays(x, y, a) 25 | 26 | 27 | @pytest.mark.parametrize("index_length,n_rays", [ 28 | (200, 30), 29 | (100, 10) 30 | ]) 31 | def test_generate_ray_indices(index_length, n_rays): 32 | xr, yr = utils.generate_ray_indices(index_length, n_rays) 33 | assert(xr.shape == yr.shape == (n_rays, index_length)) 34 | 35 | 36 | @pytest.mark.parametrize("index_length,n_rays,image", [ 37 | (200, 20, image()), 38 | (600, 5, image()) 39 | ]) 40 | def test_get_ray_values(index_length, n_rays, image): 41 | x, y = utils.generate_ray_indices(index_length, n_rays) 42 | values = utils.get_ray_values(x, y, image) 43 | assert(len(values) == n_rays) 44 | -------------------------------------------------------------------------------- /test_requirements.txt: -------------------------------------------------------------------------------- 1 | coverage>=4.1 2 | pytest>=2.9.2 3 | mock 4 | pytest-qt --------------------------------------------------------------------------------