├── .github └── workflows │ └── test-package-conda.yml ├── .gitignore ├── .travis.yml ├── CHANGELOG.md ├── LICENSE.md ├── MANIFEST.in ├── Makefile ├── README.md ├── examples └── docker_submission │ ├── README.md │ ├── docker_deepgaze3 │ ├── Dockerfile │ ├── model_server.py │ └── requirements.txt │ ├── docker_pysaliency_model │ ├── Dockerfile │ ├── model_server.py │ ├── requirements.txt │ └── sample_submission.py │ └── sample_evaluation.py ├── notebooks ├── LSUN.ipynb └── Tutorial.ipynb ├── pyproject.toml ├── pysaliency ├── __init__.py ├── baseline_utils.py ├── dataset_config.py ├── datasets │ ├── __init__.py │ ├── fixations.py │ ├── scanpaths.py │ ├── stimuli.py │ └── utils.py ├── external_datasets │ ├── __init__.py │ ├── cat2000.py │ ├── coco_freeview.py │ ├── coco_search18.py │ ├── dut_omrom.py │ ├── figrim.py │ ├── isun.py │ ├── koehler.py │ ├── mit.py │ ├── nusef.py │ ├── osie.py │ ├── pascal_s.py │ ├── salicon.py │ ├── scripts │ │ ├── extract_fixations.m │ │ └── load_cat2000.m │ ├── toronto.py │ └── utils.py ├── external_models │ ├── __init__.py │ ├── deepgaze.py │ ├── matlab_models.py │ ├── models.py │ ├── scripts │ │ ├── AIM_wrapper.m │ │ ├── BMS │ │ │ ├── BMS_wrapper.m │ │ │ └── patches │ │ │ │ ├── adapt_opencv_paths.diff │ │ │ │ ├── correct_add_path.diff │ │ │ │ ├── fix_FileGettor.diff │ │ │ │ └── series │ │ ├── ContextAwareSaliency_wrapper.m │ │ ├── CovSal_wrapper.m │ │ ├── GBVS │ │ │ ├── GBVSIttiKoch_wrapper.m │ │ │ ├── GBVS_wrapper.m │ │ │ └── patches │ │ │ │ ├── get_path │ │ │ │ ├── make_mex_files_octave_compatible │ │ │ │ └── series │ │ ├── IttiKoch_wrapper.m │ │ ├── Judd │ │ │ ├── FaceDetect_patches │ │ │ │ ├── change_opencv_include │ │ │ │ └── series │ │ │ ├── JuddSaliencyModel_patches │ │ │ │ ├── find_cascade_file │ │ │ │ ├── locate_FelzenszwalbDetector_files │ │ │ │ └── series │ │ │ ├── Judd_wrapper.m │ │ │ ├── SaliencyToolbox_patches │ │ │ │ ├── enable_unit16 │ │ │ │ └── series │ │ │ └── voc_patches │ │ │ │ ├── change_fconv │ │ │ │ ├── matlabR2014a_compatible │ │ │ │ ├── matlabR2021a_compatible │ │ │ │ └── series │ │ ├── RARE2012_wrapper.m │ │ ├── SUN_wrapper.m │ │ └── ensure_image_is_color_image.m │ └── utils.py ├── filter_datasets.py ├── http_models.py ├── metric_optimization.py ├── metric_optimization_tf.py ├── metric_optimization_torch.py ├── metrics.py ├── models.py ├── numba_utils.py ├── optpy │ ├── README.md │ ├── __init__.py │ ├── jacobian.py │ └── optimization.py ├── plotting.py ├── precomputed_models.py ├── quilt.py ├── roc.py ├── roc_cython.pyx ├── saliency_map_conversion.py ├── saliency_map_conversion_theano.py ├── saliency_map_conversion_torch.py ├── saliency_map_models.py ├── sampling_models.py ├── tf_utils.py ├── theano_utils.py ├── torch_datasets.py ├── torch_utils.py └── utils │ ├── __init__.py │ └── variable_length_array.py ├── pytest.ini ├── requirements.txt ├── setup.py └── tests ├── conftest.py ├── datasets ├── test_datasets.py ├── test_fixations.py ├── test_scanpaths.py ├── test_stimuli.py └── utils.py ├── external_datasets ├── test_COCO_Search18.py ├── test_NUSEF.py ├── test_PASCAL_S.py ├── test_SALICON.py └── test_coco_freeview.py ├── external_models ├── AIM_color_stimulus.npy ├── AIM_grayscale_stimulus.npy ├── ContextAwareSaliency_color_stimulus.npy ├── ContextAwareSaliency_grayscale_stimulus.npy ├── CovSal_color_stimulus.npy ├── CovSal_grayscale_stimulus.npy ├── GBVSIttiKoch_color_stimulus.npy ├── GBVSIttiKoch_grayscale_stimulus.npy ├── GBVS_color_stimulus.npy ├── GBVS_grayscale_stimulus.npy ├── IttiKoch_color_stimulus.npy ├── 
IttiKoch_grayscale_stimulus.npy ├── Judd_color_stimulus.npy ├── Judd_grayscale_stimulus.npy ├── RARE2007_color_stimulus.npy ├── RARE2007_grayscale_stimulus.npy ├── RARE2012_color_stimulus.npy ├── RARE2012_grayscale_stimulus.npy ├── SUN_color_stimulus.npy ├── SUN_grayscale_stimulus.npy ├── color_stimulus.npy ├── grayscale_stimulus.npy └── test_deepgaze.py ├── skippedtest_theano_utils.py ├── test_baseline_utils.py ├── test_crossvalidation.py ├── test_dataset_config.py ├── test_external_datasets.py ├── test_external_models.py ├── test_filter_datasets.py ├── test_helpers.py ├── test_metric_optimization.py ├── test_metric_optimization_tf.py ├── test_metric_optimization_torch.py ├── test_models.py ├── test_numba_utils.py ├── test_precomputed_models.py ├── test_quilt.py ├── test_quilt ├── .pc │ ├── .quilt_patches │ ├── .quilt_series │ ├── .version │ ├── add_numbers.diff │ │ ├── .timestamp │ │ └── source.txt │ └── applied-patches ├── patches │ ├── add_numbers.diff │ └── series ├── source.txt ├── source │ ├── .pc │ │ ├── .quilt_patches │ │ ├── .quilt_series │ │ └── .version │ ├── patches │ └── source.txt └── target │ └── source.txt ├── test_saliency_map_conversion.py ├── test_saliency_map_conversion_theano.py ├── test_saliency_map_conversion_torch.py ├── test_saliency_map_conversion_torch_extended.py ├── test_saliency_map_models.py ├── test_sampling.py ├── test_torch_datasets.py ├── test_torch_utils.py ├── test_utils.py └── utils └── test_variable_length_array.py /.github/workflows/test-package-conda.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build-linux: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | max-parallel: 5 10 | matrix: 11 | python-version: 12 | # - "3.7" # conda takes forever to install the dependencies 13 | - "3.8" 14 | - "3.9" 15 | - "3.10" 16 | - "3.11" 17 | steps: 18 | - uses: actions/checkout@v2 19 | - uses: conda-incubator/setup-miniconda@v2 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | channels: conda-forge 23 | - name: Conda info 24 | # the shell setting is necessary for loading profile etc which activates the conda environment 25 | shell: bash -el {0} 26 | run: conda info 27 | - name: Install dependencies 28 | shell: bash -el {0} 29 | run: | 30 | conda install \ 31 | boltons \ 32 | cython \ 33 | deprecation \ 34 | dill \ 35 | diskcache \ 36 | h5py \ 37 | imageio \ 38 | natsort \ 39 | numba \ 40 | numpy \ 41 | numpydoc \ 42 | pandas \ 43 | piexif \ 44 | pillow \ 45 | pip \ 46 | pkg-config \ 47 | pytorch \ 48 | requests \ 49 | schema \ 50 | scikit-learn \ 51 | scipy \ 52 | setuptools \ 53 | sphinx \ 54 | torchvision \ 55 | tqdm 56 | - name: Conda list 57 | shell: bash -el {0} 58 | run: conda list 59 | - name: Test with pytest 60 | shell: bash -el {0} 61 | # hypothesis=6.113.0 is the last version with python 3.8 support 62 | run: | 63 | conda install pytest hypothesis=6.113.0 64 | python setup.py build_ext --inplace 65 | python -m pytest --nomatlab --notheano --nodownload tests 66 | - name: test build and install 67 | shell: bash -el {0} 68 | run: | 69 | python setup.py sdist 70 | pip install dist/*.tar.gz 71 | mkdir tmp && cd tmp && python -c "import pysaliency" 72 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .hypothesis 2 | .mypy_cache 3 | /tmp/ 4 | /build/ 5 | /pysaliency_datasets/ 6 | /test_datasets/ 7 | /test_models/ 
8 | *.pyc 9 | *.swp 10 | *.c 11 | *.so 12 | *.egg-info 13 | .vscode 14 | .DS_Store 15 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | # - 2.7 4 | - 3.6 5 | - 3.7 6 | before_install: 7 | - sudo apt-get update 8 | - sudo apt-get install g++ 9 | - if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then 10 | wget https://repo.continuum.io/miniconda/Miniconda2-latest-Linux-x86_64.sh -O miniconda.sh; 11 | else 12 | wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh; 13 | fi 14 | - bash miniconda.sh -b -p $HOME/miniconda 15 | - source "$HOME/miniconda/etc/profile.d/conda.sh" 16 | - hash -r 17 | - conda config --set always_yes yes --set changeps1 no 18 | - conda config --add channels conda-forge 19 | - conda update -q conda 20 | # Useful for debugging any issues with conda 21 | - conda info -a 22 | # - wget https://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh -O miniconda.sh 23 | # - chmod +x miniconda.sh 24 | # - ./miniconda.sh -b -p $HOME/miniconda 25 | # - export PATH=/home/travis/miniconda/bin:$PATH 26 | # - conda config --set always_yes yes --set changeps1 no 27 | # - conda config --add channels conda-forge 28 | # - conda update -q conda 29 | # - pip install tqdm 30 | install: 31 | - conda create -q -n test-env python=$TRAVIS_PYTHON_VERSION numpy scipy cython six setuptools sphinx numpydoc pkg-config pillow tqdm boltons natsort requests dill pytest theano imageio scikit-learn pandas pytorch hypothesis 32 | # - conda install -q -n test-env tensorflow=1.13.2 33 | - conda activate test-env 34 | before_script: 35 | - conda info 36 | - conda list 37 | - pip --version 38 | - pip freeze 39 | script: 40 | # Make sure the library installs. 41 | - python setup.py install 42 | - python setup.py build_ext --inplace 43 | - python -c "import pysaliency" # make sure can be installed without theano, pytorch, and tensorflow 44 | - conda install theano pytorch 45 | - pip install tensorflow==1.13.2 schema 46 | - python -m pytest --nomatlab tests 47 | - python setup.py sdist 48 | - pip install dist/*.tar.gz 49 | - mkdir tmp && cd tmp && python -c "import pysaliency" 50 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2015 Matthias Kümmerer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
-------------------------------------------------------------------------------- /MANIFEST.in: --------------------------------------------------------------------------------
1 | include README.md
2 | include LICENSE.md
3 | include pysaliency/*.pyx
4 | include pysaliency/scripts/*.m
5 | include pysaliency/scripts/models/*.m
6 | include pysaliency/scripts/models/*/*.m
7 | include pysaliency/scripts/models/*/*/*
8 | include pysaliency/scripts/models/BMS/patches/*
9 | include pysaliency/scripts/models/GBVS/patches/*
10 | include pysaliency/scripts/models/Judd/patches/*
-------------------------------------------------------------------------------- /Makefile: --------------------------------------------------------------------------------
1 | cython:
2 | 	#python setup.py build_ext --inplace
3 | 	python3 setup.py build_ext --inplace
4 | 
5 | test: cython
6 | 	python3 -m pytest --nomatlab tests
7 | 
8 | prepublish:
9 | 	./run-docker.sh rm -rf dist
10 | 	./run-docker.sh bash build.sh
11 | 	twine upload --repository=pysaliency-test dist/pysaliency*.tar.gz  # assumes that ~/.pypirc defines a pysaliency-test entry, see https://test.pypi.org/manage/account/token/
12 | 	# twine upload dist/pysaliency*.tar.gz -r testpypi
13 | 
14 | 
15 | publish:
16 | 	./run-docker.sh rm -rf dist
17 | 	./run-docker.sh bash build.sh
18 | 	twine upload --repository=pysaliency dist/pysaliency*.tar.gz
19 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | Pysaliency
2 | ==========
3 | 
4 | ![test](https://github.com/matthias-k/pysaliency/actions/workflows/test-package-conda.yml/badge.svg)
5 | 
6 | Pysaliency is a Python package for saliency modelling. It aims at providing a unified interface
7 | to both the traditional saliency maps used in saliency modeling and probabilistic saliency
8 | models.
9 | 
10 | Pysaliency can evaluate most commonly used saliency metrics, including AUC, sAUC, NSS, CC,
11 | image-based KL divergence, fixation-based KL divergence and SIM for saliency map models, and
12 | log likelihoods and information gain for probabilistic models.
13 | 
14 | Installation
15 | ------------
16 | 
17 | You can install pysaliency from pypi via
18 | 
19 |     pip install pysaliency
20 | 
21 | 
22 | Quickstart
23 | ----------
24 | 
25 |     import matplotlib.pyplot as plt
26 |     import pysaliency
27 | 
28 |     dataset_location = 'datasets'
29 |     model_location = 'models'
30 | 
31 |     mit_stimuli, mit_fixations = pysaliency.external_datasets.get_mit1003(location=dataset_location)
32 |     aim = pysaliency.AIM(location=model_location)
33 |     saliency_map = aim.saliency_map(mit_stimuli.stimuli[0])
34 | 
35 |     plt.imshow(saliency_map)
36 | 
37 | 
38 |     auc = aim.AUC(mit_stimuli, mit_fixations)
39 | 
40 | If you already have saliency maps for some dataset, you can import them into pysaliency easily:
41 | 
42 |     my_model = pysaliency.SaliencyMapModelFromDirectory(mit_stimuli, '/path/to/my/saliency_maps')
43 |     auc = my_model.AUC(mit_stimuli, mit_fixations)
44 | 
45 | Check out the [Tutorial](notebooks/Tutorial.ipynb) for a more detailed introduction!
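46 | 
47 | The other metrics work in the same way. The following sketch shows a few more of them; it assumes the objects from the quickstart above, and the exact set of available metric methods may vary slightly between pysaliency versions:
48 | 
49 |     # shuffled AUC (sAUC) and NSS for the saliency map model from above
50 |     sauc = my_model.AUC(mit_stimuli, mit_fixations, nonfixations='shuffled')
51 |     nss = my_model.NSS(mit_stimuli, mit_fixations)
52 | 
53 |     # probabilistic models additionally provide log-likelihoods and information
54 |     # gain, shown here for the uniform baseline model
55 |     baseline = pysaliency.UniformModel()
56 |     ig = baseline.information_gain(mit_stimuli, mit_fixations, average='image')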
57 | 
58 | Included datasets and models
59 | ----------------------------
60 | 
61 | Pysaliency provides several important datasets:
62 | 
63 | * MIT1003
64 | * MIT300
65 | * CAT2000
66 | * Toronto
67 | * Koehler
68 | * iSUN
69 | * SALICON (both the 2015 and the 2017 edition and each with both the original mouse traces and the inferred fixations)
70 | * FIGRIM
71 | * OSIE
72 | * NUSEF (the part with public images)
73 | 
74 | and some influential models:
75 | * AIM
76 | * SUN
77 | * ContextAwareSaliency
78 | * BMS
79 | * GBVS
80 | * GBVSIttiKoch
81 | * Judd
82 | * IttiKoch
83 | * RARE2012
84 | * CovSal
85 | 
86 | These models use the original code, which is often MATLAB.
87 | Therefore, a MATLAB licence is required to make use of these models, although quite a few of them
88 | also work with Octave (see below).
89 | 
90 | 
91 | Using Octave
92 | ------------
93 | 
94 | pysaliency will fall back to Octave if no MATLAB is installed.
95 | Some models might work with Octave, e.g. AIM and GBVSIttiKoch. In Debian/Ubuntu you need to install
96 | `octave`, `octave-image`, `octave-statistics`, `liboctave-dev`.
97 | 
98 | These models and datasets seem to work with Octave:
99 | 
100 | - models
101 |   - AIM
102 |   - GBVSIttiKoch
103 | - datasets
104 |   - Toronto
105 |   - MIT1003
106 |   - MIT300
107 |   - SALICON
108 | 
109 | Dependencies
110 | ------------
111 | 
112 | The Judd model needs some libraries to work. In Ubuntu/Debian you need to install these packages:
113 | `libopencv-core-dev`, `libopencv-flann-dev`, `libopencv-imgproc-dev`, `libopencv-photo-dev`, `libopencv-video-dev`, `libopencv-features2d-dev`, `libopencv-objdetect-dev`, `libopencv-calib3d-dev`, `libopencv-ml-dev`, as well as a package providing the header `opencv2/contrib/contrib.hpp` (e.g. `libopencv-contrib-dev`).
-------------------------------------------------------------------------------- /examples/docker_submission/README.md: --------------------------------------------------------------------------------
1 | # Submission
2 | 
3 | This directory contains an example of how to create a Docker container for submitting a scanpath model to the MIT/Tübingen Saliency Benchmark. You'll need to build a Docker or Singularity container that offers a JSON API for requesting model predictions. The benchmark will use `pysaliency.http_models.HTTPScanpathModel` to interact with your model.
4 | 
5 | ## Preparing the submission
6 | 
7 | 1. Create a Docker or Singularity container that exposes your model as an API compatible with `pysaliency.http_models.HTTPScanpathModel`. There are two different examples contained here:
8 |    - `docker_pysaliency_model`: A Docker container exposing a pysaliency model (which is implemented in `sample_submission.py`). Use this if you already have a pysaliency implementation of your model.
9 |    - `docker_deepgaze3`: A Docker container exposing the DeepGaze III model. It demonstrates how to implement the API for an arbitrary model.
10 | 
11 | 2. Build the Docker container as described in the "Launching the submission container" section.
12 | 
13 | 
14 | ## Launching the submission container
15 | 
16 | In this example, we will use the `docker_pysaliency_model` directory to create a Docker container that exposes a pysaliency model. The container will run a Flask server that listens for HTTP requests and responds with model predictions.
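17 | 
18 | Under the hood, `HTTPScanpathModel` talks to two endpoints on this server. The following sketch shows how to query them manually with `requests` (the field names follow the `model_server.py` implementations below; `stimulus.png` is a placeholder for an image file of your choice):
19 | 
20 | ```python
21 | import json
22 | 
23 | import requests
24 | 
25 | # the /type endpoint reports the model type and API version
26 | print(requests.get("http://localhost:4000/type").json())
27 | 
28 | # /conditional_log_density takes the stimulus as a file upload and the
29 | # scanpath history as a JSON-encoded form field
30 | with open("stimulus.png", "rb") as f:
31 |     response = requests.post(
32 |         "http://localhost:4000/conditional_log_density",
33 |         data={"json_data": json.dumps({
34 |             "x_hist": [512.0, 300.0],  # x coordinates of the previous fixations
35 |             "y_hist": [384.0, 200.0],  # y coordinates of the previous fixations
36 |             "t_hist": [0.0, 0.5],      # timestamps of the previous fixations
37 |         })},
38 |         files={"stimulus": f},
39 |     )
40 | log_density = response.json()["log_density"]  # height x width array of log density values
41 | ```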
42 | 
43 | First, we have to build the container:
44 | ```bash
45 | docker build -t sample_pysaliency docker_pysaliency_model
46 | ```
47 | 
48 | Then we can start it:
49 | ```bash
50 | docker run --rm -it -p 4000:4000 sample_pysaliency
51 | ```
52 | The above command will launch the image as an interactive container in the foreground
53 | and expose the port `4000` to the host machine.
54 | If you prefer to run it in the background, use
55 | 
56 | ```bash
57 | docker run --name sample_pysaliency -dp 4000:4000 sample_pysaliency
58 | ```
59 | which will launch a container named `sample_pysaliency` running in the background.
60 | 
61 | To test the model server, run the sample_evaluation script, which will evaluate the model on the MIT1003 dataset. Make sure to have the `pysaliency` package installed:
62 | ```bash
63 | python ./sample_evaluation.py
64 | ```
65 | 
66 | To delete the background container, run the following command:
67 | ```bash
68 | docker stop sample_pysaliency && docker rm sample_pysaliency
69 | ```
70 | 
71 | # TODOs
72 | 
73 | - [ ] Establish and discuss how arguments can be passed to the model server, e.g. information about the image resolution in dva or other parameters.
-------------------------------------------------------------------------------- /examples/docker_submission/docker_deepgaze3/Dockerfile: --------------------------------------------------------------------------------
1 | # Specify a base image depending on the project.
2 | FROM bitnami/python:3.8
3 | # For more complex examples, you might need to use a different base image.
4 | # FROM pytorch/pytorch:1.9.1-cuda11.1-cudnn8-runtime
5 | 
6 | WORKDIR /app
7 | 
8 | ENV HTTP_PORT=4000
9 | 
10 | RUN apt-get update \
11 |     && apt-get -y install gcc \
12 |     && apt-get clean \
13 |     && rm -rf /var/lib/apt/lists/* /var/cache/apt/*
14 | 
15 | COPY ./requirements.txt ./
16 | RUN python -m pip install --no-cache -U pip \
17 |     && python -m pip install --no-cache -r requirements.txt
18 | 
19 | COPY ./model_server.py ./
20 | # COPY ./sample_submission.py ./
21 | 
22 | # This is needed for Singularity builds.
23 | EXPOSE $HTTP_PORT
24 | 
25 | # The entrypoint for the container.
26 | CMD ["gunicorn", "-w", "1", "-b", "0.0.0.0:4000", "--pythonpath", ".", "--access-logfile", "-", "model_server:app"]
-------------------------------------------------------------------------------- /examples/docker_submission/docker_deepgaze3/model_server.py: --------------------------------------------------------------------------------
1 | from flask import Flask, request
2 | import numpy as np
3 | import json
4 | from PIL import Image
5 | from io import BytesIO
6 | import orjson
7 | from scipy.ndimage import zoom
8 | from scipy.special import logsumexp
9 | import torch
10 | 
11 | # Import your model here
12 | import deepgaze_pytorch
13 | 
14 | # Flask server
15 | app = Flask("saliency-model-server")
16 | app.logger.setLevel("DEBUG")
17 | 
18 | # TODO - replace this with your model
19 | model = deepgaze_pytorch.DeepGazeIII(pretrained=True)
20 | 
21 | def get_fixation_history(fixation_coordinates, model):
22 |     history = []
23 |     for index in model.included_fixations:
24 |         try:
25 |             history.append(fixation_coordinates[index])
26 |         except IndexError:
27 |             # for early fixations, not all previous fixations exist
28 |             history.append(np.nan)
29 |     return np.array(history)
30 | 
31 | @app.route('/conditional_log_density', methods=['POST'])
32 | def conditional_log_density():
33 |     # get data
34 |     data = json.loads(request.form['json_data'])
35 | 
36 |     # extract the scanpath history and select the fixations the model expects
37 |     x_hist = get_fixation_history(np.array(data['x_hist']), model)
38 |     y_hist = get_fixation_history(np.array(data['y_hist']), model)
39 |     # t_hist = np.array(data['t_hist'])
40 |     # attributes = data.get('attributes', {})
41 | 
42 |     # extract stimulus
43 |     image_bytes = request.files['stimulus'].read()
44 |     image = Image.open(BytesIO(image_bytes))
45 |     stimulus = np.array(image)
46 | 
47 |     # centerbias for deepgaze3 model
48 |     centerbias_template = np.zeros((1024, 1024))
49 |     centerbias = zoom(centerbias_template,
50 |                       (stimulus.shape[0]/centerbias_template.shape[0],
51 |                        stimulus.shape[1]/centerbias_template.shape[1]),
52 |                       order=0,
53 |                       mode='nearest'
54 |                       )
55 |     centerbias -= logsumexp(centerbias)
56 | 
57 |     # make tensors for deepgaze3 model
58 |     image_tensor = torch.tensor([stimulus.transpose(2, 0, 1)])
59 |     centerbias_tensor = torch.tensor([centerbias])
60 |     # get_fixation_history already selected model.included_fixations,
61 |     # so the histories can be used directly here
62 |     x_hist_tensor = torch.tensor([x_hist])
63 |     y_hist_tensor = torch.tensor([y_hist])
64 | 
65 |     # return model response
66 |     log_density = model(image_tensor, centerbias_tensor, x_hist_tensor, y_hist_tensor)
67 |     log_density_list = log_density.tolist()
68 |     response = orjson.dumps({'log_density': log_density_list})
69 |     return response
70 | 
71 | 
72 | @app.route('/type', methods=['GET'])
73 | def type():
74 |     type = "ScanpathModel"
75 |     version = "v1.0.0"
76 |     return orjson.dumps({'type': type, 'version': version})
77 | 
78 | 
79 | def main():
80 |     app.run(host="localhost", port=4000, debug=True, threaded=True)
81 | 
82 | 
83 | if __name__ == "__main__":
84 |     main()
-------------------------------------------------------------------------------- /examples/docker_submission/docker_deepgaze3/requirements.txt: --------------------------------------------------------------------------------
1 | cython
2 | flask
3 | gunicorn
4 | numpy
5 | 
6 | # Add additional dependencies here
7 | pysaliency
8 | scipy
9 | torch
10 | flask_orjson
11 | git+https://github.com/matthias-k/deepgaze
-------------------------------------------------------------------------------- /examples/docker_submission/docker_pysaliency_model/Dockerfile: --------------------------------------------------------------------------------
1 | # Specify a base image depending on the project.
2 | FROM bitnami/python:3.8
3 | # For more complex examples, you might need to use a different base image.
4 | # FROM pytorch/pytorch:1.9.1-cuda11.1-cudnn8-runtime
5 | 
6 | WORKDIR /app
7 | 
8 | ENV HTTP_PORT=4000
9 | 
10 | RUN apt-get update \
11 |     && apt-get -y install gcc \
12 |     && apt-get clean \
13 |     && rm -rf /var/lib/apt/lists/* /var/cache/apt/*
14 | 
15 | COPY ./requirements.txt ./
16 | RUN python -m pip install --no-cache -U pip \
17 |     && python -m pip install --no-cache -r requirements.txt
18 | 
19 | COPY ./model_server.py ./
20 | COPY ./sample_submission.py ./
21 | 
22 | # This is needed for Singularity builds.
23 | EXPOSE $HTTP_PORT
24 | 
25 | # The entrypoint for the container.
26 | CMD ["gunicorn", "-w", "1", "-b", "0.0.0.0:4000", "--pythonpath", ".", "--access-logfile", "-", "model_server:app"]
-------------------------------------------------------------------------------- /examples/docker_submission/docker_pysaliency_model/model_server.py: --------------------------------------------------------------------------------
1 | from flask import Flask, request
2 | from flask_orjson import OrjsonProvider
3 | import numpy as np
4 | import json
5 | from PIL import Image
6 | from io import BytesIO
7 | import orjson
8 | 
9 | 
10 | # Import your model here
11 | from sample_submission import MySimpleScanpathModel
12 | 
13 | app = Flask("saliency-model-server")
14 | app.json_provider = OrjsonProvider(app)
15 | app.logger.setLevel("DEBUG")
16 | 
17 | # TODO - replace this with your model
18 | model = MySimpleScanpathModel()
19 | 
20 | @app.route('/conditional_log_density', methods=['POST'])
21 | def conditional_log_density():
22 |     data = json.loads(request.form['json_data'])
23 |     x_hist = np.array(data['x_hist'])
24 |     y_hist = np.array(data['y_hist'])
25 |     t_hist = np.array(data['t_hist'])
26 |     attributes = data.get('attributes', {})
27 | 
28 |     image_bytes = request.files['stimulus'].read()
29 |     image = Image.open(BytesIO(image_bytes))
30 |     stimulus = np.array(image)
31 | 
32 |     log_density = model.conditional_log_density(stimulus, x_hist, y_hist, t_hist, attributes)
33 |     log_density_list = log_density.tolist()
34 |     response = orjson.dumps({'log_density': log_density_list})
35 |     return response
36 | 
37 | 
38 | @app.route('/type', methods=['GET'])
39 | def type():
40 |     type = "ScanpathModel"
41 |     version = "v1.0.0"
42 |     return orjson.dumps({'type': type, 'version': version})
43 | 
44 | 
45 | def main():
46 |     app.run(host="localhost", port=4000, debug=True, threaded=True)
47 | 
48 | 
49 | if __name__ == "__main__":
50 |     main()
-------------------------------------------------------------------------------- /examples/docker_submission/docker_pysaliency_model/requirements.txt: --------------------------------------------------------------------------------
1 | cython
2 | flask
3 | gunicorn
4 | numpy
5 | 
6 | # Add additional dependencies here
7 | pysaliency
8 | scipy
9 | torch
10 | flask_orjson
11 | 
-------------------------------------------------------------------------------- /examples/docker_submission/docker_pysaliency_model/sample_submission.py: --------------------------------------------------------------------------------
1 | """
2 | This file implements a very simple scanpath
model using local contrast and a saccadic prior.
3 | """
4 | 
5 | import numpy as np
6 | from typing import Union
7 | from scipy.ndimage import gaussian_filter
8 | import pysaliency
9 | 
10 | 
11 | class LocalContrastModel(pysaliency.Model):
12 |     def __init__(self, bandwidth=0.05, **kwargs):
13 |         super().__init__(**kwargs)
14 |         self.bandwidth = bandwidth
15 | 
16 |     def _log_density(self, stimulus: Union[pysaliency.datasets.Stimulus, np.ndarray]):
17 | 
18 |         # _log_density can either take pysaliency Stimulus objects, or, for convenience, simply numpy arrays
19 |         # `as_stimulus` ensures that we have a Stimulus object
20 |         stimulus_object = pysaliency.datasets.as_stimulus(stimulus)
21 | 
22 |         # grayscale image
23 |         gray_stimulus = np.mean(stimulus_object.stimulus_data, axis=2)
24 | 
25 |         # size contains the height and width of the image, but not potential color channels
26 |         height, width = stimulus_object.size
27 | 
28 |         # define kernel size based on image size
29 |         kernel_size = np.round(self.bandwidth * max(width, height)).astype(int)
30 |         sigma = (kernel_size - 1) / 6
31 | 
32 |         # apply Gaussian blur and calculate the squared difference between blurred and original image
33 |         blurred_stimulus = gaussian_filter(gray_stimulus, sigma)
34 | 
35 |         prediction = gaussian_filter((gray_stimulus - blurred_stimulus)**2, sigma)
36 | 
37 |         # normalize to [1, 255]
38 |         prediction = (254 * (prediction / prediction.max())).astype(int) + 1
39 | 
40 |         density = prediction / prediction.sum()
41 | 
42 |         return np.log(density)
43 | 
44 | class MySimpleScanpathModel(pysaliency.ScanpathModel):
45 |     def __init__(self, spatial_model_bandwidth: float=0.05, saccade_width: float=0.1):
46 |         self.spatial_model_bandwidth = spatial_model_bandwidth
47 |         self.saccade_width = saccade_width
48 |         self.spatial_model = LocalContrastModel(spatial_model_bandwidth)
49 |         # self.spatial_model = pysaliency.UniformModel()
50 | 
51 | 
52 |     def conditional_log_density(self, stimulus, x_hist, y_hist, t_hist, attributes=None, out=None):
53 |         stimulus_object = pysaliency.datasets.as_stimulus(stimulus)
54 | 
55 |         # size contains the height and width of the image, but not potential color channels
56 |         height, width = stimulus_object.size
57 | 
58 |         spatial_prior_log_density = self.spatial_model.log_density(stimulus)
59 |         spatial_prior_density = np.exp(spatial_prior_log_density)
60 | 
61 |         # compute saccade bias
62 |         last_x = x_hist[-1]
63 |         last_y = y_hist[-1]
64 | 
65 |         xs = np.arange(width, dtype=float)
66 |         ys = np.arange(height, dtype=float)
67 |         XS, YS = np.meshgrid(xs, ys)
68 | 
69 |         XS -= last_x
70 |         YS -= last_y
71 | 
72 |         # compute prior
73 |         max_size = max(width, height)
74 |         actual_kernel_size = self.saccade_width * max_size
75 | 
76 |         saccade_bias = np.exp(-0.5 * (XS ** 2 + YS ** 2) / actual_kernel_size ** 2)
77 | 
78 |         prediction = spatial_prior_density * saccade_bias
79 | 
80 |         density = prediction / prediction.sum()
81 |         return np.log(density)
82 | 
-------------------------------------------------------------------------------- /examples/docker_submission/sample_evaluation.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import sys
3 | from pysaliency.http_models import HTTPScanpathModel
4 | sys.path.insert(0, '..')
5 | import pysaliency
6 | 
7 | 
8 | from tqdm import tqdm
9 | 
10 | # import deepgaze_pytorch  # only needed for the commented-out comparison below
11 | 
12 | def get_fixation_history(fixation_coordinates, model):
13 |     history = []
14 |     for index in model.included_fixations:
15 |         try:
16 |             history.append(fixation_coordinates[index])
17 |         except IndexError:
18 |             history.append(np.nan)
19 |     return history
20 | 
21 | if __name__ == "__main__":
22 | 
23 |     # initialize HTTPScanpathModel
24 |     http_model = HTTPScanpathModel("http://localhost:4000")
25 |     http_model.check_type()
26 | 
27 |     # for testing
28 |     # test_model = deepgaze_pytorch.DeepGazeIII(pretrained=True)
29 | 
30 |     # get MIT1003 dataset
31 |     stimuli, fixations = pysaliency.get_mit1003(location='pysaliency_datasets')
32 | 
33 |     # only use the first 1000 fixations for testing; scanpath models need at least one previous fixation
34 |     eval_fixations = fixations[fixations.scanpath_history_length > 0][:1000]
35 | 
36 | 
37 |     information_gain = http_model.information_gain(stimuli, eval_fixations, average="image", verbose=True)
38 |     print("IG:", information_gain)
39 | 
40 |     # for fixation_index in tqdm(range(len(eval_fixations))):
41 | 
42 |         # get server response for one stimulus
43 |         # server_density = http_model.conditional_log_density(
44 |         #     stimulus=stimuli.stimuli[eval_fixations.n[fixation_index]],
45 |         #     x_hist=eval_fixations.x_hist[fixation_index],
46 |         #     y_hist=eval_fixations.y_hist[fixation_index],
47 |         #     t_hist=eval_fixations.t_hist[fixation_index]
48 |         # )
49 |         # get test model response
50 |         # test_model_density = test_model(
51 |         #     stimulus=stimuli.stimuli[eval_fixations.n[fixation_index]],
52 |         #     x_hist=eval_fixations.x_hist[fixation_index],
53 |         #     y_hist=eval_fixations.y_hist[fixation_index],
54 |         #     t_hist=eval_fixations.t_hist[fixation_index]
55 |         # )
56 | 
57 |         # Testing
58 |         # test = np.testing.assert_allclose(server_density, test_model_density)
-------------------------------------------------------------------------------- /notebooks/LSUN.ipynb: --------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "# LSUN Challenge"
8 |    ]
9 |   },
10 |   {
11 |    "cell_type": "markdown",
12 |    "metadata": {},
13 |    "source": [
14 |     "This document describes how to set up and run the Python code for the LSUN saliency evaluation"
15 |    ]
16 |   },
17 |   {
18 |    "cell_type": "markdown",
19 |    "metadata": {},
20 |    "source": [
21 |     "## Setup"
22 |    ]
23 |   },
24 |   {
25 |    "cell_type": "markdown",
26 |    "metadata": {},
27 |    "source": [
28 |     "With your favorite python package management tool, install the needed libraries (here shown for `pip`):\n",
29 |     "\n",
30 |     "    pip install numpy scipy theano Cython natsort dill hdf5storage\n",
31 |     "    pip install git+https://github.com/matthias-k/optpy\n",
32 |     "    pip install git+https://github.com/matthias-k/pysaliency\n",
33 |     "\n",
34 |     "If you want to use the SALICON dataset, you also need to install the\n",
35 |     "[SALICON API](https://github.com/NUS-VIP/salicon-api)."
36 |    ]
37 |   },
38 |   {
39 |    "cell_type": "markdown",
40 |    "metadata": {},
41 |    "source": [
42 |     "## Usage\n",
43 |     "\n",
44 |     "start by importing pysaliency:\n",
45 |     "\n",
46 |     "    import pysaliency\n",
47 |     "\n",
48 |     "you probably also want to load the LSUN datasets:\n",
49 |     "\n",
50 |     "    dataset_location = 'datasets' # where to cache datasets\n",
51 |     "    stimuli_salicon_train, fixations_salicon_train = pysaliency.get_SALICON_train(location=dataset_location)\n",
52 |     "    stimuli_salicon_val, fixations_salicon_val = pysaliency.get_SALICON_val(location=dataset_location)\n",
53 |     "    stimuli_salicon_test = pysaliency.get_SALICON_test(location=dataset_location)\n",
54 |     "    \n",
55 |     "    stimuli_isun_train, stimuli_isun_val, stimuli_isun_test, fixations_isun_train, fixations_isun_val = pysaliency.get_iSUN(location=dataset_location)"
56 |    ]
57 |   },
58 |   {
59 |    "cell_type": "code",
60 |    "execution_count": 4,
61 |    "metadata": {
62 |     "collapsed": true
63 |    },
64 |    "outputs": [],
65 |    "source": [
66 |     "# TODO: Add ModelFromDirectory for log densities\n",
67 |     "# TODO: Change defaults for saliency map convertor (at least in LSUN subclass)\n",
68 |     "# TODO: Write fit functions optimize_for_information_gain(model, stimuli, fixations)"
69 |    ]
70 |   },
71 |   {
72 |    "cell_type": "markdown",
73 |    "metadata": {},
74 |    "source": [
75 |     "### Import your saliency model into pysaliency\n",
76 |     "\n",
77 |     "If you did not develop your model in the pysaliency framework, you have to import the generated saliency maps or log-densities into pysaliency. If you have the saliency maps saved to a directory with names corresponding to the stimuli\n",
78 |     "filenames, use `pysaliency.SaliencyMapModelFromDirectory`. You can save your saliency maps as png, jpg, tiff, mat or npy files."
79 |    ]
80 |   },
81 |   {
82 |    "cell_type": "code",
83 |    "execution_count": null,
84 |    "metadata": {
85 |     "collapsed": true
86 |    },
87 |    "outputs": [],
88 |    "source": [
89 |     "my_model = pysaliency.SaliencyMapModelFromDirectory(stimuli_salicon_train, \"my_model/saliency_maps/SALICON_TRAIN\")"
90 |    ]
91 |   },
92 |   {
93 |    "cell_type": "markdown",
94 |    "metadata": {},
95 |    "source": [
96 |     "If you have an LSUN submission file prepared, you can load it with `pysaliency.SaliencyMapModelFromFile`:"
97 |    ]
98 |   },
99 |   {
100 |    "cell_type": "code",
101 |    "execution_count": null,
102 |    "metadata": {
103 |     "collapsed": true
104 |    },
105 |    "outputs": [],
106 |    "source": [
107 |     "my_model = pysaliency.SaliencyMapModelFromFile(stimuli_salicon_train, \"my_model/salicon_train.mat\")"
108 |    ]
109 |   },
110 |   {
111 |    "cell_type": "markdown",
112 |    "metadata": {},
113 |    "source": [
114 |     "### Evaluate your model"
115 |    ]
116 |   },
117 |   {
118 |    "cell_type": "markdown",
119 |    "metadata": {},
120 |    "source": [
121 |     "Evaluating your model with pysaliency is fairly easy. In general, the evaluation functions take the stimuli and fixations to evaluate on, and maybe some additional configuration parameters. The following metrics are used in the LSUN saliency challenge (additionally, the information gain metric is used, see below):"
122 |    ]
123 |   },
124 |   {
125 |    "cell_type": "code",
126 |    "execution_count": null,
127 |    "metadata": {
128 |     "collapsed": true
129 |    },
130 |    "outputs": [],
131 |    "source": [
132 |     "my_model.AUC(stimuli_salicon_train, fixations_salicon_train, nonfixations='uniform')\n",
133 |     "my_model.AUC(stimuli_salicon_train, fixations_salicon_train, nonfixations='shuffled')"
134 |    ]
135 |   },
136 |   {
137 |    "cell_type": "markdown",
138 |    "metadata": {},
139 |    "source": [
140 |     "### Optimize your model for information gain"
141 |    ]
142 |   },
143 |   {
144 |    "cell_type": "markdown",
145 |    "metadata": {},
146 |    "source": [
147 |     "If you wish to hand in a probabilistic model, you might wish to optimize the model for the nonlinearity and centerbiases\n",
148 |     "of the datasets. Otherwise we will optimize all saliency map models for information gain using a subset of the iSUN dataset using the following code. Feel free to adapt it to your needs (for example, use more images for fitting)."
149 |    ]
150 |   },
151 |   {
152 |    "cell_type": "code",
153 |    "execution_count": null,
154 |    "metadata": {
155 |     "collapsed": true
156 |    },
157 |    "outputs": [],
158 |    "source": [
159 |     "my_probabilistic_model = pysaliency.SaliencyMapConvertor(my_model, ...)\n",
160 |     "fit_stimuli, fit_fixations = pysaliency.create_subset(stimuli_isun_train, fixations_isun_train, range(0, 500))\n",
161 |     "my_probabilistic_model = pysaliency.optimize_for_information_gain(\n",
162 |     "    my_model, fit_stimuli, fit_fixations,\n",
163 |     "    num_nonlinearity=20,\n",
164 |     "    num_centerbias=12,\n",
165 |     "    optimize=[\n",
166 |     "        'nonlinearity',\n",
167 |     "        'centerbias',\n",
168 |     "        'alpha',\n",
169 |     "        #'blurradius',  # we do not optimize the blurring.\n",
170 |     "    ])"
171 |    ]
172 |   },
173 |   {
174 |    "cell_type": "markdown",
175 |    "metadata": {},
176 |    "source": [
177 |     "### Hand in your model"
178 |    ]
179 |   },
180 |   {
181 |    "cell_type": "code",
182 |    "execution_count": null,
183 |    "metadata": {
184 |     "collapsed": true
185 |    },
186 |    "outputs": [],
187 |    "source": []
188 |   }
189 |  ],
190 |  "metadata": {
191 |   "kernelspec": {
192 |    "display_name": "Python 3",
193 |    "language": "python",
194 |    "name": "python3"
195 |   },
196 |   "language_info": {
197 |    "codemirror_mode": {
198 |     "name": "ipython",
199 |     "version": 3
200 |    },
201 |    "file_extension": ".py",
202 |    "mimetype": "text/x-python",
203 |    "name": "python",
204 |    "nbconvert_exporter": "python",
205 |    "pygments_lexer": "ipython3",
206 |    "version": "3.5.1+"
207 |   }
208 |  },
209 |  "nbformat": 4,
210 |  "nbformat_minor": 0
211 | }
-------------------------------------------------------------------------------- /pyproject.toml: --------------------------------------------------------------------------------
1 | [tool.ruff]
2 | select = ["B", "E", "F", "FIX", "I", "T20"]
3 | line-length = 200
4 | ignore = ["T201"] # ignore print statements
5 | 
6 | [build-system]
7 | requires = [
8 |     "numpy",
9 |     "setuptools",
10 |     "wheel",
11 |     "Cython"
12 | ]
13 | build-backend = "setuptools.build_meta"
-------------------------------------------------------------------------------- /pysaliency/__init__.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function, unicode_literals
2 | 
3 | from . import datasets
4 | from . import saliency_map_models
5 | from . import models
6 | from . 
import external_models 7 | from . import external_datasets 8 | from . import utils 9 | 10 | from .datasets import ( 11 | Fixations, 12 | FixationTrains, 13 | ScanpathFixations, 14 | Scanpaths, 15 | Stimuli, 16 | FileStimuli, 17 | create_nonfixations, 18 | create_subset, 19 | remove_out_of_stimulus_fixations, 20 | concatenate_datasets, 21 | read_hdf5, 22 | ) 23 | from .dataset_config import load_dataset_from_config 24 | from .saliency_map_models import ( 25 | SaliencyMapModel, 26 | GeneralSaliencyMapModel, 27 | ScanpathSaliencyMapModel, 28 | GaussianSaliencyMapModel, 29 | FixationMap, 30 | CachedSaliencyMapModel, 31 | ExpSaliencyMapModel, 32 | DisjointUnionSaliencyMapModel, 33 | SubjectDependentSaliencyMapModel, 34 | StimulusDependentSaliencyMapModel, 35 | ResizingSaliencyMapModel, 36 | BluringSaliencyMapModel, 37 | DigitizeMapModel, 38 | HistogramNormalizedSaliencyMapModel, 39 | DensitySaliencyMapModel, 40 | LogDensitySaliencyMapModel, 41 | EqualizedSaliencyMapModel, 42 | WTASamplingMixin, 43 | ) 44 | from .sampling_models import SamplingModelMixin, ScanpathSamplingModelMixin 45 | from .models import ( 46 | ScanpathModel, 47 | GeneralModel, 48 | Model, 49 | UniformModel, 50 | CachedModel, 51 | MixtureModel, 52 | DisjointUnionModel, 53 | SubjectDependentModel, 54 | ShuffledAUCSaliencyMapModel, 55 | ResizingModel, 56 | ResizingScanpathModel, 57 | StimulusDependentModel, 58 | ) 59 | from .saliency_map_conversion import ( 60 | optimize_for_information_gain, 61 | ) 62 | from .precomputed_models import ( 63 | SaliencyMapModelFromFiles, 64 | SaliencyMapModelFromDirectory, 65 | SaliencyMapModelFromFile, 66 | ModelFromDirectory, 67 | HDF5SaliencyMapModel, 68 | HDF5Model, 69 | export_model_to_hdf5, 70 | ) 71 | 72 | from .external_models import ( 73 | AIM, 74 | SUN, 75 | ContextAwareSaliency, 76 | BMS, 77 | GBVS, 78 | GBVSIttiKoch, 79 | Judd, 80 | IttiKoch, 81 | RARE2012, 82 | CovSal, 83 | ) 84 | from .external_datasets import ( 85 | get_mit1003, 86 | get_mit1003_onesize, 87 | get_cat2000_train, 88 | get_cat2000_test, 89 | get_toronto, 90 | get_iSUN_training, 91 | get_iSUN_validation, 92 | get_iSUN_testing, 93 | get_SALICON, 94 | get_SALICON_train, 95 | get_SALICON_val, 96 | get_SALICON_test, 97 | get_mit300, 98 | get_koehler, 99 | get_FIGRIM, 100 | get_OSIE, 101 | get_NUSEF_public, 102 | ) 103 | 104 | from .metric_optimization import SIMSaliencyMapModel 105 | -------------------------------------------------------------------------------- /pysaliency/dataset_config.py: -------------------------------------------------------------------------------- 1 | from .datasets import read_hdf5, clip_out_of_stimulus_fixations, remove_out_of_stimulus_fixations 2 | from .filter_datasets import ( 3 | filter_fixations_by_number, 4 | filter_stimuli_by_number, 5 | filter_stimuli_by_size, 6 | train_split, 7 | validation_split, 8 | test_split, 9 | filter_scanpaths_by_attribute, 10 | filter_fixations_by_attribute, 11 | filter_stimuli_by_attribute, 12 | filter_scanpaths_by_length, 13 | remove_stimuli_without_fixations 14 | ) 15 | 16 | from schema import Schema, Optional 17 | 18 | 19 | dataset_config_schema = Schema({ 20 | 'stimuli': str, 21 | 'fixations': str, 22 | Optional('filters', default=[]): [{ 23 | 'type': str, 24 | Optional('parameters', default={}): dict, 25 | }], 26 | }) 27 | 28 | 29 | def load_dataset_from_config(config): 30 | config = dataset_config_schema.validate(config) 31 | stimuli = read_hdf5(config['stimuli']) 32 | fixations = read_hdf5(config['fixations']) 33 | 34 | for filter_config in 
config['filters']: 35 | stimuli, fixations = apply_dataset_filter_config(stimuli, fixations, filter_config) 36 | 37 | return stimuli, fixations 38 | 39 | 40 | def apply_dataset_filter_config(stimuli, fixations, filter_config): 41 | filter_dict = { 42 | 'filter_fixations_by_number': add_stimuli_argument(filter_fixations_by_number), 43 | 'filter_stimuli_by_number': filter_stimuli_by_number, 44 | 'filter_stimuli_by_size': filter_stimuli_by_size, 45 | 'clip_out_of_stimulus_fixations': _clip_out_of_stimulus_fixations, 46 | 'remove_out_of_stimulus_fixations': _remove_out_of_stimulus_fixations, 47 | 'train_split': train_split, 48 | 'validation_split': validation_split, 49 | 'test_split': test_split, 50 | 'filter_scanpaths_by_attribute': add_stimuli_argument(filter_scanpaths_by_attribute), 51 | 'filter_fixations_by_attribute': add_stimuli_argument(filter_fixations_by_attribute), 52 | 'filter_stimuli_by_attribute': filter_stimuli_by_attribute, 53 | 'filter_scanpaths_by_length': add_stimuli_argument(filter_scanpaths_by_length), 54 | 'remove_stimuli_without_fixations': remove_stimuli_without_fixations 55 | } 56 | 57 | if filter_config['type'] not in filter_dict: 58 | raise ValueError("Invalid filter name: {}".format(filter_config['type'])) 59 | 60 | filter_fn = filter_dict[filter_config['type']] 61 | 62 | return filter_fn(stimuli, fixations, **filter_config['parameters']) 63 | 64 | 65 | def _clip_out_of_stimulus_fixations(stimuli, fixations): 66 | clipped_fixations = clip_out_of_stimulus_fixations(fixations, stimuli=stimuli) 67 | return stimuli, clipped_fixations 68 | 69 | 70 | def _remove_out_of_stimulus_fixations(stimuli, fixations): 71 | filtered_fixations = remove_out_of_stimulus_fixations(stimuli, fixations) 72 | return stimuli, filtered_fixations 73 | 74 | 75 | def add_stimuli_argument(fn): 76 | def wrapped(stimuli, fixations, **kwargs): 77 | new_fixations = fn(fixations, **kwargs) 78 | return stimuli, new_fixations 79 | 80 | return wrapped 81 | -------------------------------------------------------------------------------- /pysaliency/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pathlib 4 | import warnings 5 | from collections.abc import Sequence 6 | from functools import wraps 7 | from hashlib import sha1 8 | from typing import Dict, List, Optional, Union 9 | from weakref import WeakValueDictionary 10 | 11 | import numpy as np 12 | from boltons.cacheutils import cached 13 | 14 | from ..utils.variable_length_array import VariableLengthArray, concatenate_variable_length_arrays 15 | 16 | try: 17 | from imageio.v3 import imread 18 | except ImportError: 19 | from imageio import imread 20 | from PIL import Image 21 | from tqdm import tqdm 22 | 23 | from ..utils import LazyList, remove_trailing_nans 24 | 25 | 26 | def hdf5_wrapper(mode=None): 27 | def decorator(f): 28 | @wraps(f) 29 | def wrapped(self, target, *args, **kwargs): 30 | if isinstance(target, (str, pathlib.Path)): 31 | import h5py 32 | with h5py.File(target, mode) as hdf5_file: 33 | return f(self, hdf5_file, *args, **kwargs) 34 | else: 35 | return f(self, target, *args, **kwargs) 36 | 37 | return wrapped 38 | return decorator 39 | 40 | 41 | def decode_string(data): 42 | if not isinstance(data, str): 43 | return data.decode('utf8') 44 | 45 | return data 46 | 47 | def create_hdf5_dataset(target, name, data): 48 | import h5py 49 | 50 | if isinstance(np.array(data).flatten()[0], str): 51 | data = np.array(data) 52 | original_shape = data.shape 53 
| encoded_items = [decode_string(item).encode('utf8') for item in data.flatten()] 54 | encoded_data = np.array(encoded_items).reshape(original_shape) 55 | 56 | target.create_dataset( 57 | name, 58 | data=encoded_data, 59 | dtype=h5py.special_dtype(vlen=str) 60 | ) 61 | else: 62 | target.create_dataset(name, data=data) 63 | 64 | 65 | def get_merged_attribute_list(attributes): 66 | all_attributes = set(attributes[0]) 67 | common_attributes = set(attributes[0]) 68 | 69 | for _attributes in attributes[1:]: 70 | all_attributes = all_attributes.union(_attributes) 71 | common_attributes = common_attributes.intersection(_attributes) 72 | 73 | if common_attributes != all_attributes: 74 | lost_attributes = all_attributes.difference(common_attributes) 75 | warnings.warn(f"Discarding attributes which are not present everywhere: {lost_attributes}", stacklevel=4) 76 | 77 | return sorted(common_attributes) 78 | 79 | def _load_attribute_dict_from_hdf5(attribute_group): 80 | json_attributes = attribute_group.attrs['__attributes__'] 81 | if not isinstance(json_attributes, str): 82 | json_attributes = json_attributes.decode('utf8') 83 | __attributes__ = json.loads(json_attributes) 84 | 85 | attributes = {attribute: attribute_group[attribute][...] for attribute in __attributes__} 86 | return attributes 87 | 88 | 89 | def concatenate_attributes(attributes): 90 | attributes = list(attributes) 91 | 92 | if isinstance(attributes[0], VariableLengthArray): 93 | return concatenate_variable_length_arrays(attributes) 94 | 95 | attributes = [np.array(a) for a in attributes] 96 | for a in attributes: 97 | assert len(a.shape) == len(attributes[0].shape) 98 | 99 | if len(attributes[0].shape) == 1: 100 | return np.hstack(attributes) 101 | 102 | else: 103 | assert len(attributes[0].shape) == 2 104 | max_cols = max(a.shape[1] for a in attributes) 105 | padded_attributes = [] 106 | for a in attributes: 107 | if a.shape[1] < max_cols: 108 | padding = np.empty((a.shape[0], max_cols-a.shape[1]), dtype=a.dtype) 109 | padding[:] = np.nan 110 | padded_attributes.append(np.hstack((a, padding))) 111 | else: 112 | padded_attributes.append(a) 113 | return np.vstack(padded_attributes) -------------------------------------------------------------------------------- /pysaliency/external_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, print_function, division 2 | 3 | import zipfile 4 | import os 5 | import glob 6 | 7 | import numpy as np 8 | from scipy.io import loadmat 9 | from tqdm import tqdm 10 | 11 | from ..datasets import FixationTrains 12 | from ..utils import ( 13 | download_and_check, 14 | atomic_directory_setup, 15 | ) 16 | 17 | from .utils import create_stimuli, _load 18 | 19 | from .toronto import get_toronto, get_toronto_with_subjects 20 | from .mit import get_mit1003, get_mit1003_with_initial_fixation, get_mit1003_onesize, get_mit300 21 | from .cat2000 import get_cat2000_test, get_cat2000_train 22 | from .isun import get_iSUN, get_iSUN_training, get_iSUN_validation, get_iSUN_testing 23 | from .salicon import get_SALICON, get_SALICON_train, get_SALICON_val, get_SALICON_test 24 | from .koehler import get_koehler 25 | from .figrim import get_FIGRIM 26 | from .osie import get_OSIE 27 | from .nusef import get_NUSEF_public 28 | from .pascal_s import get_PASCAL_S 29 | from .dut_omrom import get_DUT_OMRON 30 | from .coco_search18 import get_COCO_Search18, get_COCO_Search18_train, get_COCO_Search18_validation 31 | from .coco_freeview 
import get_COCO_Freeview, get_COCO_Freeview_train, get_COCO_Freeview_validation, get_COCO_Freeview_test
32 | 
-------------------------------------------------------------------------------- /pysaliency/external_datasets/dut_omrom.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | 
3 | import glob
4 | import os
5 | import zipfile
6 | from tempfile import TemporaryDirectory
7 | from typing import Tuple
8 | 
9 | import numpy as np
10 | from scipy.io import loadmat
11 | from tqdm import tqdm
12 | 
13 | from ..datasets import ScanpathFixations, Scanpaths, Stimuli
14 | from ..utils import (
15 |     atomic_directory_setup,
16 |     download_and_check,
17 | )
18 | from .utils import _load, create_stimuli
19 | 
20 | 
21 | def get_DUT_OMRON(location=None) -> Tuple[Stimuli, ScanpathFixations]:
22 |     """
23 |     Loads or downloads the DUT-OMRON fixation dataset.
24 |     The dataset consists of 5168 natural images with
25 |     a maximal size of 400 pixels and eye movement data
26 |     from 5 subjects in a 2 second free viewing task.
27 |     The eye movement data has been filtered and preprocessed,
28 |     see the dataset documentation for more details.
29 | 
30 |     Note that the dataset contains subject ids but they
31 |     might not be consistent across images.
32 | 
33 |     @type location: string, defaults to `None`
34 |     @param location: If and where to cache the dataset. The dataset
35 |                      will be stored in the subdirectory `DUT-OMRON` of
36 |                      location and read from there, if already present.
37 |     @return: Stimuli, ScanpathFixations
38 | 
39 |     .. seealso::
40 | 
41 |         Chuan Yang, Lihe Zhang, Huchuan Lu, Xiang Ruan, Minghsuan Yang. Saliency Detection Via Graph-Based Manifold Ranking, CVPR2013.
42 | 
43 |         http://saliencydetection.net/dut-omron
44 |     """
45 |     if location:
46 |         location = os.path.join(location, 'DUT-OMRON')
47 |         if os.path.exists(location):
48 |             stimuli = _load(os.path.join(location, 'stimuli.hdf5'))
49 |             fixations = _load(os.path.join(location, 'fixations.hdf5'))
50 |             return stimuli, fixations
51 |         os.makedirs(location)
52 | 
53 |     n_fixations = 0
54 | 
55 |     with atomic_directory_setup(location):
56 |         with TemporaryDirectory() as temp_dir:
57 | 
58 |             download_and_check('http://saliencydetection.net/dut-omron/download/DUT-OMRON-image.zip',
59 |                                os.path.join(temp_dir, 'DUT-OMRON-image.zip'),
60 |                                'a8951db9297afacf78bc0e5079103cf1')
61 | 
62 |             download_and_check('http://saliencydetection.net/dut-omron/download/DUT-OMRON-eye-fixations.zip',
63 |                                os.path.join(temp_dir, 'DUT-OMRON-eye-fixations.zip'),
64 |                                'd9f4f83fcc78b1e5efb579ae9fb0edc2')
65 | 
66 |             # Stimuli
67 |             print('Creating stimuli')
68 |             f = zipfile.ZipFile(os.path.join(temp_dir, 'DUT-OMRON-image.zip'))
69 |             f.extractall(temp_dir)
70 | 
71 |             stimuli_src_location = os.path.join(temp_dir, 'DUT-OMRON-image')
72 |             images = glob.glob(os.path.join(stimuli_src_location, '*.jpg'))
73 |             images = [os.path.relpath(img, start=stimuli_src_location) for img in images]
74 |             stimuli_filenames = sorted(images)
75 | 
76 |             stimuli_target_location = os.path.join(location, 'Stimuli') if location else None
77 |             stimuli = create_stimuli(stimuli_src_location, stimuli_filenames, stimuli_target_location)
78 | 
79 |             stimuli_basenames = [os.path.basename(f) for f in stimuli_filenames]
80 | 
81 |             # FixationTrains
82 | 
83 |             print('Creating fixations')
84 |             f = zipfile.ZipFile(os.path.join(temp_dir, 'DUT-OMRON-eye-fixations.zip'))
85 |             f.extractall(temp_dir)
86 | 
87 |             train_xs = []
88 |             train_ys = []
89 |             train_ts = []
90 |             train_ns = []
91 |             train_subjects = []
92 | 
93 |             for n, basename in enumerate(tqdm(stimuli_basenames)):
94 |                 eye_filename = os.path.join(temp_dir, 'DUT-OMRON-eye-fixations', 'mat', basename.replace('.jpg', '.mat'))
95 |                 eye_data = loadmat(eye_filename)['s']
96 |                 xs, ys, subject_ids = eye_data.T
97 |                 n_fixations += len(xs) - 1  # first entry is image size
98 |                 for subject_index in range(subject_ids.max()):
99 |                     subject_inds = subject_ids == subject_index + 1  # subject==0 is image size
100 | 
101 |                     if not np.any(subject_inds):
102 |                         continue
103 | 
104 |                     # since there are coordinates with value 0, we assume they are 0-indexed (although they are matlab)
105 |                     train_xs.append(xs[subject_inds])
106 |                     train_ys.append(ys[subject_inds])
107 |                     train_ts.append(np.arange(subject_inds.sum()))
108 |                     train_ns.append(n)
109 |                     train_subjects.append(subject_index)
110 | 
111 |     fixations = ScanpathFixations(Scanpaths(
112 |         xs=train_xs,
113 |         ys=train_ys,
114 |         ts=train_ts,
115 |         n=train_ns,
116 |         subject=train_subjects,
117 |     ))
118 | 
119 |     assert len(fixations) == n_fixations
120 | 
121 |     if location:
122 |         stimuli.to_hdf5(os.path.join(location, 'stimuli.hdf5'))
123 |         fixations.to_hdf5(os.path.join(location, 'fixations.hdf5'))
124 |     return stimuli, fixations
125 | 
-------------------------------------------------------------------------------- /pysaliency/external_datasets/koehler.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | 
3 | import itertools
4 | import os
5 | import zipfile
6 | from tempfile import TemporaryDirectory
7 | 
8 | import numpy as np
9 | from scipy.io import loadmat
10 | from tqdm import tqdm
11 | 
12 | from ..datasets import FixationTrains
13 | from ..utils import (
14 |     atomic_directory_setup,
15 |     check_file_hash,
16 | )
17 | from .utils import _load, create_stimuli
18 | 
19 | 
20 | def _get_koehler_fixations(data, task, n_stimuli):
21 |     tasks = {'freeviewing': 'freeview',
22 |              'objectsearch': 'objsearch',
23 |              'saliencysearch': 'salview'}
24 |     task = tasks[task]
25 | 
26 |     data_x = data['{}_x'.format(task)]
27 |     data_y = data['{}_y'.format(task)]
28 | 
29 |     # Load Fixation Data
30 |     xs = []
31 |     ys = []
32 |     ts = []
33 |     ns = []
34 |     subjects = []
35 |     subject_ids = range(data_x.shape[0])
36 | 
37 |     for n, subject_id in tqdm(list(itertools.product(range(n_stimuli), subject_ids))):
38 |         x = data_x[subject_id, n, :] - 1
39 |         y = data_y[subject_id, n, :] - 1
40 |         inds = np.logical_not(np.isnan(x))
41 |         x = x[inds]
42 |         y = y[inds]
43 |         xs.append(x)
44 |         ys.append(y)
45 |         ts.append(range(len(x)))
46 |         ns.append(n)
47 |         subjects.append(subject_id)
48 |     return FixationTrains.from_fixation_trains(xs, ys, ts, ns, subjects)
49 | 
50 | 
51 | def get_koehler(location=None, datafile=None):
52 |     """
53 |     Loads or extracts and caches the Koehler dataset. The dataset
54 |     consists of 800 color images of outdoor and indoor scenes
55 |     of size 405x405px and the fixations for three different tasks:
56 |     free viewing, object search and saliency search.
57 | 
58 |     @type location: string, defaults to `None`
59 |     @param location: If and where to cache the dataset. The dataset
60 |                      will be stored in the subdirectory `koehler` of
61 |                      location and read from there, if already present.
62 |     @return: stimuli, fixations_freeviewing, fixations_objectsearch, fixations_saliencysearch
63 | 
64 |     .. note:: As this dataset is only available after filling a download form, pysaliency
65 |               cannot download it for you. 
Instead you have to download the file 66 | `PublicData.zip` and provide it to this function via the `datafile` 67 | keyword argument on the first call. 68 | 69 | .. seealso:: 70 | 71 | Kathryn Koehler, Fei Guo, Sheng Zhang, Miguel P. Eckstein. What Do Saliency Models Predict? [JoV 2014] 72 | 73 | http://www.journalofvision.org/content/14/3/14.full 74 | 75 | https://labs.psych.ucsb.edu/eckstein/miguel/research_pages/saliencydata.html 76 | """ 77 | if location: 78 | location = os.path.join(location, 'Koehler') 79 | if os.path.exists(location): 80 | stimuli = _load(os.path.join(location, 'stimuli.hdf5')) 81 | fixations_freeviewing = _load(os.path.join(location, 'fixations_freeviewing.hdf5')) 82 | fixations_objectsearch = _load(os.path.join(location, 'fixations_objectsearch.hdf5')) 83 | fixations_saliencysearch = _load(os.path.join(location, 'fixations_saliencysearch.hdf5')) 84 | return stimuli, fixations_freeviewing, fixations_objectsearch, fixations_saliencysearch 85 | #mkdir_p(location) 86 | if not datafile: 87 | raise ValueError('The Koehler dataset is not freely available! You have to ' 88 | 'request the data file from the authors and provide it to ' 89 | 'this function via the datafile argument') 90 | check_file_hash(datafile, '405f58aaa9b4ddc76f3e8f23c379d315') 91 | with atomic_directory_setup(location): 92 | with TemporaryDirectory() as temp_dir: 93 | z = zipfile.ZipFile(datafile) 94 | print('Extracting') 95 | z.extractall(temp_dir) 96 | 97 | # Stimuli 98 | stimuli_src_location = os.path.join(temp_dir, 'Images') 99 | stimuli_target_location = os.path.join(location, 'stimuli') if location else None 100 | stimuli_filenames = ['image_r_{}.jpg'.format(i) for i in range(1, 801)] 101 | 102 | stimuli = create_stimuli(stimuli_src_location, stimuli_filenames, stimuli_target_location) 103 | 104 | # Fixations 105 | 106 | data = loadmat(os.path.join(temp_dir, 'ObserverData.mat')) 107 | 108 | fs = [] 109 | 110 | for task in ['freeviewing', 'objectsearch', 'saliencysearch']: 111 | fs.append(_get_koehler_fixations(data, task, len(stimuli))) 112 | 113 | if location: 114 | stimuli.to_hdf5(os.path.join(location, 'stimuli.hdf5')) 115 | fs[0].to_hdf5(os.path.join(location, 'fixations_freeviewing.hdf5')) 116 | fs[1].to_hdf5(os.path.join(location, 'fixations_objectsearch.hdf5')) 117 | fs[2].to_hdf5(os.path.join(location, 'fixations_saliencysearch.hdf5')) 118 | return [stimuli] + fs -------------------------------------------------------------------------------- /pysaliency/external_datasets/osie.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import os 4 | import urllib 5 | from tempfile import TemporaryDirectory 6 | 7 | import numpy as np 8 | from boltons.fileutils import mkdir_p 9 | from scipy.io import loadmat 10 | from tqdm import tqdm 11 | 12 | from ..datasets import FixationTrains 13 | from ..utils import ( 14 | atomic_directory_setup, 15 | download_and_check, 16 | ) 17 | from .utils import _load, create_stimuli 18 | 19 | 20 | def get_OSIE(location=None): 21 | """ 22 | Loads or downloads and caches the OSIE dataset. The dataset 23 | consists of 700 images of size 800x600px 24 | and the fixations of 15 subjects while doing a 25 | freeviewing task with 3 seconds presentation time. 26 | 27 | @type location: string, defaults to `None` 28 | @param location: If and where to cache the dataset. 
The dataset 29 | will be stored in the subdirectory `OSIE` of 30 | location and read from there, if already present. 31 | @return: Stimuli, FixationTrains 32 | 33 | .. seealso:: 34 | 35 | Juan Xu, Ming Jiang, Shuo Wang, Mohan Kankanhalli, Qi Zhao. Predicting Human Gaze Beyond Pixels [JoV 2014] 36 | 37 | http://www-users.cs.umn.edu/~qzhao/predicting.html 38 | """ 39 | if location: 40 | location = os.path.join(location, 'OSIE') 41 | if os.path.exists(location): 42 | stimuli = _load(os.path.join(location, 'stimuli.hdf5')) 43 | fixations = _load(os.path.join(location, 'fixations.hdf5')) 44 | return stimuli, fixations 45 | os.makedirs(location) 46 | with atomic_directory_setup(location): 47 | with TemporaryDirectory() as temp_dir: 48 | stimuli_src_location = os.path.join(temp_dir, 'stimuli') 49 | mkdir_p(stimuli_src_location) 50 | images = [] 51 | for i in tqdm(list(range(700))): 52 | filename = '{}.jpg'.format(i + 1001) 53 | target_name = os.path.join(stimuli_src_location, filename) 54 | urllib.request.urlretrieve( 55 | 'https://github.com/NUS-VIP/predicting-human-gaze-beyond-pixels/raw/master/data/stimuli/' + filename, 56 | target_name) 57 | images.append(filename) 58 | 59 | download_and_check('https://github.com/NUS-VIP/predicting-human-gaze-beyond-pixels/blob/master/data/eye/fixations.mat?raw=true', 60 | os.path.join(temp_dir, 'fixations.mat'), 61 | '8efdf6fe66f38b6e70f854c7ff45aa70') 62 | 63 | # Stimuli 64 | print('Creating stimuli') 65 | 66 | stimuli_target_location = os.path.join(location, 'Stimuli') if location else None 67 | stimuli = create_stimuli(stimuli_src_location, images, stimuli_target_location) 68 | 69 | stimulus_indices = {s: images.index(s) for s in images} 70 | 71 | # FixationTrains 72 | 73 | print('Creating fixations') 74 | data = loadmat(os.path.join(temp_dir, 'fixations.mat'))['fixations'].flatten() 75 | 76 | xs = [] 77 | ys = [] 78 | ts = [] 79 | ns = [] 80 | train_subjects = [] 81 | 82 | for stimulus_data in data: 83 | stimulus_data = stimulus_data[0, 0] 84 | n = stimulus_indices[stimulus_data['img'][0]] 85 | for subject, subject_data in enumerate(stimulus_data['subjects'].flatten()): 86 | fixations = subject_data[0, 0] 87 | if not len(fixations['fix_x'].flatten()): 88 | continue 89 | 90 | xs.append(fixations['fix_x'].flatten()) 91 | ys.append(fixations['fix_y'].flatten()) 92 | ts.append(np.arange(len(xs[-1]))) 93 | ns.append(n) 94 | train_subjects.append(subject) 95 | 96 | fixations = FixationTrains.from_fixation_trains(xs, ys, ts, ns, train_subjects) 97 | 98 | if location: 99 | stimuli.to_hdf5(os.path.join(location, 'stimuli.hdf5')) 100 | fixations.to_hdf5(os.path.join(location, 'fixations.hdf5')) 101 | return stimuli, fixations -------------------------------------------------------------------------------- /pysaliency/external_datasets/pascal_s.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import os 4 | import zipfile 5 | from tempfile import TemporaryDirectory 6 | 7 | import numpy as np 8 | import requests 9 | 10 | from ..datasets import ScanpathFixations, Scanpaths 11 | from ..utils import ( 12 | atomic_directory_setup, 13 | download_and_check, 14 | ) 15 | from .utils import _load, create_stimuli 16 | 17 | 18 | def get_PASCAL_S(location=None): 19 | """ 20 | Loads or downloads and caches the PASCAL-S dataset.
21 | The dataset consists of 850 images from the PASCAL-VOC 22 | 2010 validation set with fixations from 12 subjects 23 | during 2s free-viewing. 24 | 25 | Note that here only the eye movement data from PASCAL-S 26 | is included. The original dataset also provides 27 | salient object segmentation data. 28 | 29 | @type location: string, defaults to `None` 30 | @param location: If and where to cache the dataset. The dataset 31 | will be stored in the subdirectory `PASCAL-S` of 32 | location and read from there, if already present. 33 | @return: Stimuli, ScanpathFixations 34 | 35 | .. seealso:: 36 | 37 | Yin Li, Xiaodi Hu, Christof Koch, James M. Rehg, Alan L. Yuille: 38 | The Secrets of Salient Object Segmentation. CVPR 2014. 39 | 40 | http://cbs.ic.gatech.edu/salobj/ 41 | """ 42 | if location: 43 | location = os.path.join(location, 'PASCAL-S') 44 | if os.path.exists(location): 45 | stimuli = _load(os.path.join(location, 'stimuli.hdf5')) 46 | fixations = _load(os.path.join(location, 'fixations.hdf5')) 47 | return stimuli, fixations 48 | os.makedirs(location) 49 | 50 | n_stimuli = 850 51 | 52 | with atomic_directory_setup(location): 53 | with TemporaryDirectory() as temp_dir: 54 | 55 | try: 56 | download_and_check('http://cbs.ic.gatech.edu/salobj/download/salObj.zip', 57 | os.path.join(temp_dir, 'salObj.zip'), 58 | 'e48b4e5deac08bddcaec55ce56e4d420') 59 | except requests.exceptions.SSLError: 60 | print("http://cbs.ic.gatech.edu/salobj/download/salObj.zip seems to be using an invalid SSL certificate. Since this is known and since we're checking the MD5 sum in addition, we'll ignore the invalid certificate.") 61 | download_and_check('http://cbs.ic.gatech.edu/salobj/download/salObj.zip', 62 | os.path.join(temp_dir, 'salObj.zip'), 63 | 'e48b4e5deac08bddcaec55ce56e4d420', 64 | verify_ssl=False) 65 | 66 | # Stimuli 67 | print('Creating stimuli') 68 | f = zipfile.ZipFile(os.path.join(temp_dir, 'salObj.zip')) 69 | f.extractall(temp_dir) 70 | 71 | stimuli_src_location = os.path.join(temp_dir, 'datasets', 'imgs', 'pascal') 72 | stimuli_filenames = ['{}.jpg'.format(i + 1) for i in range(n_stimuli)] 73 | 74 | stimuli_target_location = os.path.join(location, 'Stimuli') if location else None 75 | stimuli = create_stimuli(stimuli_src_location, stimuli_filenames, stimuli_target_location) 76 | 77 | print('Creating fixations') 78 | 79 | train_xs = [] 80 | train_ys = [] 81 | train_ts = [] 82 | train_ns = [] 83 | train_subjects = [] 84 | 85 | import h5py # we don't import globally to avoid a hard dependency on h5py 86 | with h5py.File(os.path.join(temp_dir, 'datasets', 'fixations', 'pascalFix.mat'), mode='r') as hdf5_file: 87 | fixation_data = [hdf5_file[hdf5_file['fixCell'][0, stimulus_index]][:] for stimulus_index in range(n_stimuli)] 88 | 89 | for n in range(n_stimuli): 90 | ys, xs, subject_ids = fixation_data[n] 91 | for subject in sorted(set(subject_ids)): 92 | subject_inds = subject_ids == subject 93 | if not np.any(subject_inds): 94 | continue 95 | 96 | train_xs.append(xs[subject_inds] - 1) # data is 1-indexed in matlab 97 | train_ys.append(ys[subject_inds] - 1) 98 | train_ts.append(np.arange(subject_inds.sum())) 99 | train_ns.append(n) 100 | train_subjects.append(subject - 1) # subjects are 1-indexed in matlab 101 | 102 | fixations = ScanpathFixations(Scanpaths( 103 | xs=train_xs, 104 | ys=train_ys, 105 | ts=train_ts, 106 | n=train_ns, 107 | subject=train_subjects, 108 | )) 109 | 110 | if location: 111 | stimuli.to_hdf5(os.path.join(location, 'stimuli.hdf5')) 112 | fixations.to_hdf5(os.path.join(location,
'fixations.hdf5')) 113 | return stimuli, fixations 114 | -------------------------------------------------------------------------------- /pysaliency/external_datasets/scripts/extract_fixations.m: -------------------------------------------------------------------------------- 1 | function [ ] = extract_fixations(filename, datafolder, outname) 2 | % fprintf('Loading %s %s\n', datafolder, filename); 3 | addpath('DatabaseCode') 4 | datafile = strcat(filename(1:end-4), 'mat'); 5 | load(fullfile(datafolder, datafile)); 6 | variable_name = datafile(1:end-4); 7 | % matlab's variable names have a maximum length of 63 8 | variable_name = variable_name(1:min(end,63)); 9 | stimFile = eval([variable_name]); 10 | eyeData = stimFile.DATA(1).eyeData; 11 | [eyeData Fix Sac] = checkFixations(eyeData); 12 | fixs = find(eyeData(:,3)==0); % these are the indices of the fixations in the eyeData for a given image and user 13 | fixations = Fix.medianXY; 14 | starts = Fix.start; 15 | durations = Fix.duration; 16 | save(outname, 'fixations', 'starts', 'durations', '-v6') % version 6 makes octave files compatible with scipy 17 | -------------------------------------------------------------------------------- /pysaliency/external_datasets/scripts/load_cat2000.m: -------------------------------------------------------------------------------- 1 | load trainSet/allFixData.mat; 2 | ind = 0 3 | ks=keys(allData); 4 | for i=1:size(ks,2); 5 | k=ks(i); 6 | tmp=allData(char(k)); 7 | for j=1:size(tmp,1); 8 | ttmp=cell2mat(tmp(j)); 9 | name = ttmp.name; 10 | data = ttmp.data; 11 | filename = sprintf('extracted/fix%d_%d.mat', i, j); 12 | save(filename, 'name', 'data'); 13 | ind = ind+1; 14 | disp(sprintf('%d/%d', i,j)); 15 | end; 16 | end 17 | -------------------------------------------------------------------------------- /pysaliency/external_datasets/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, print_function, division 2 | 3 | import os 4 | import shutil 5 | import warnings 6 | 7 | from ..datasets import FileStimuli, Stimuli, read_hdf5 8 | 9 | 10 | def create_memory_stimuli(filenames, attributes=None): 11 | """ 12 | Create a `Stimuli` object from a list of filenames by reading them 13 | """ 14 | tmp_stimuli = FileStimuli(filenames) 15 | stimuli = list(tmp_stimuli.stimuli) # Read all stimuli 16 | return Stimuli(stimuli, attributes=attributes) 17 | 18 | 19 | def create_stimuli(stimuli_location, filenames, location=None, attributes=None): 20 | """ 21 | Create a `Stimuli` or `FileStimuli` object from a list of stimulus files. 22 | 23 | Parameters 24 | ---------- 25 | 26 | @type stimuli_location: string 27 | @param stimuli_location: the base path where the stimuli are located. 28 | If `location` is provided, this directory will 29 | be copied to `location` (see below). 30 | 31 | @type filenames: list of strings 32 | @param filenames: lists the filenames of the stimuli to include in the dataset. 33 | Filenames have to be relative to `stimuli_location`. 34 | 35 | @type location: string or `None` 36 | @param location: If provided, the function will copy the filenames to 37 | `location` and return a `FileStimuli`-object for the 38 | copied files. Otherwise a `Stimuli`-object is returned. 39 | 40 | @returns: `Stimuli` or `FileStimuli` object depending on `location`.
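Example (a sketch; the paths and filenames here are hypothetical): create_stimuli('/tmp/raw_stimuli', ['001.jpg', '002.jpg'], location='/tmp/dataset') copies '/tmp/raw_stimuli' to '/tmp/dataset' and returns a `FileStimuli` object for the two copied files.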
41 | 42 | """ 43 | if location is not None: 44 | shutil.copytree(stimuli_location, 45 | location) 46 | filenames = [os.path.join(location, f) for f in filenames] 47 | 48 | return FileStimuli(filenames, attributes=attributes) 49 | 50 | else: 51 | filenames = [os.path.join(stimuli_location, f) for f in filenames] 52 | return create_memory_stimuli(filenames, attributes=attributes) 53 | 54 | 55 | def _load(filename): 56 | """attempt to load an hdf5 file, falling back to a pickle file if present""" 57 | if os.path.isfile(filename): 58 | return read_hdf5(filename) 59 | 60 | stem, ext = os.path.splitext(filename) 61 | pydat_filename = stem + '.pydat' 62 | 63 | if os.path.isfile(pydat_filename): 64 | import dill 65 | # warn about deprecated pickle files 66 | warnings.warn("Using pickle files is deprecated. Please convert to hdf5 files instead. Pickle support will be removed in pysaliency 0.4", DeprecationWarning, stacklevel=2) 67 | return dill.load(open(pydat_filename, 'rb')) 68 | else: 69 | raise FileNotFoundError(f"Neither {filename} nor {pydat_filename} exist.") -------------------------------------------------------------------------------- /pysaliency/external_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .deepgaze import ( 2 | DeepGazeI, 3 | DeepGazeIIE, 4 | ) 5 | from .matlab_models import ( 6 | AIM, 7 | BMS, 8 | GBVS, 9 | RARE2007, 10 | RARE2012, 11 | SUN, 12 | ContextAwareSaliency, 13 | CovSal, 14 | GBVSIttiKoch, 15 | IttiKoch, 16 | Judd, 17 | ) 18 | from .utils import ExternalModelMixin 19 | -------------------------------------------------------------------------------- /pysaliency/external_models/deepgaze.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from ..datasets import as_stimulus 5 | from ..models import Model 6 | from ..utils import as_rgb 7 | 8 | 9 | class StaticDeepGazeModel(Model): 10 | def __init__(self, centerbias_model, device=None, *args, **kwargs): 11 | super().__init__(*args, **kwargs) 12 | self.centerbias_model = centerbias_model 13 | self.torch_model = self._load_model() 14 | 15 | self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | self.torch_model.to(self.device) 17 | 18 | def _load_model(self): 19 | raise NotImplementedError() 20 | 21 | def _log_density(self, stimulus): 22 | stimulus = as_stimulus(stimulus) 23 | stimulus_data = stimulus.stimulus_data 24 | 25 | stimulus_data = as_rgb(stimulus_data) 26 | stimulus_data = stimulus_data.transpose(2, 0, 1) 27 | 28 | centerbias_data = self.centerbias_model.log_density(stimulus) 29 | 30 | image_tensor = torch.tensor(np.array([stimulus_data]), dtype=torch.float32).to(self.device) 31 | centerbias_tensor = torch.tensor(np.array([centerbias_data]), dtype=torch.float32).to(self.device) 32 | 33 | log_density_prediction = self.torch_model.forward(image_tensor, centerbias_tensor) 34 | 35 | return log_density_prediction.detach().cpu().numpy()[0].astype(np.float64) 36 | 37 | 38 | class DeepGazeI(StaticDeepGazeModel): 39 | """DeepGaze I model 40 | 41 | see https://github.com/matthias-k/DeepGaze and 42 | 43 | DeepGaze I: Kümmerer, M., Theis, L., & Bethge, M. (2015). 44 | Deep Gaze I: Boosting Saliency Prediction with Feature Maps Trained on ImageNet.
45 | ICLR Workshop Track (http://arxiv.org/abs/1411.1045) 46 | """ 47 | def __init__(self, centerbias_model, device=None, *args, **kwargs): 48 | super().__init__(centerbias_model=centerbias_model, *args, **kwargs) 49 | 50 | def _load_model(self): 51 | return torch.hub.load('matthias-k/DeepGaze', 'DeepGazeI', pretrained=True) 52 | 53 | 54 | class DeepGazeIIE(StaticDeepGazeModel): 55 | """DeepGaze IIE model 56 | 57 | see https://github.com/matthias-k/DeepGaze and 58 | 59 | DeepGaze IIE: Linardos, A., Kümmerer, M., Press, O., & Bethge, M. (2021). 60 | Calibrated prediction in and out-of-domain for state-of-the-art saliency modeling. 61 | ICCV 2021 (http://arxiv.org/abs/2105.12441) 62 | """ 63 | def __init__(self, centerbias_model, device=None, *args, **kwargs): 64 | super().__init__(centerbias_model=centerbias_model, *args, **kwargs) 65 | 66 | def _load_model(self): 67 | return torch.hub.load('matthias-k/DeepGaze', 'DeepGazeIIE', pretrained=True) 68 | 69 | def _log_density(self, stimulus): 70 | return super()._log_density(stimulus)[0] 71 | -------------------------------------------------------------------------------- /pysaliency/external_models/models.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, print_function, division, unicode_literals 2 | 3 | import os 4 | import tempfile 5 | import zipfile 6 | import tarfile 7 | from pkg_resources import resource_string, resource_listdir 8 | 9 | from boltons.fileutils import mkdir_p 10 | import numpy as np 11 | from scipy.ndimage import zoom 12 | 13 | from ..utils import download_and_check, run_matlab_cmd 14 | from ..quilt import QuiltSeries 15 | from ..saliency_map_models import MatlabSaliencyMapModel, SaliencyMapModel 16 | 17 | from .utils import write_file, extract_zipfile, unpack_directory, apply_quilt, download_extract_patch, ExternalModelMixin -------------------------------------------------------------------------------- /pysaliency/external_models/scripts/AIM_wrapper.m: -------------------------------------------------------------------------------- 1 | function [ ] = AIM_wrapper(filename, outname, convolve, filters) 2 | 3 | addpath('AIM'); 4 | 5 | saliency_map = AIM(filename, 1.0, convolve, filters); 6 | save(outname, 'saliency_map', '-v6'); 7 | 8 | -------------------------------------------------------------------------------- /pysaliency/external_models/scripts/BMS/BMS_wrapper.m: -------------------------------------------------------------------------------- 1 | function [ ] = BMS(filename, outname) 2 | 3 | addpath('source') 4 | 5 | directory = sprintf('tmp_%s', int2str(randi(1e15))); 6 | while exist(directory, 'dir') 7 | directory = sprintf('tmp_%s', int2str(randi(1e15))); 8 | end 9 | mkdir(directory); 10 | copyfile(filename, directory) 11 | 12 | output_directory = fullfile(directory, 'output'); 13 | BMS(directory, output_directory, false); 14 | 15 | [path, name, ext] = fileparts(filename); 16 | outfile = fullfile(output_directory, sprintf('%s.png', name)); 17 | saliency_map = imread(outfile); 18 | save(outname, 'saliency_map'); 19 | 20 | rmdir(directory, 's'); 21 | -------------------------------------------------------------------------------- /pysaliency/external_models/scripts/BMS/patches/adapt_opencv_paths.diff: -------------------------------------------------------------------------------- 1 | Index: src/mex/compile.m 2 | =================================================================== 3 | --- src.orig/mex/compile.m 2013-09-14 14:27:14.000000000 
+0200 4 | +++ src/mex/compile.m 2014-07-04 16:12:38.350680006 +0200 5 | @@ -5,8 +5,8 @@ 6 | function compile() 7 | 8 | % set the values 9 | -opts.opencv_include_path = 'C:\opencv240\install\include'; % OpenCV include path 10 | -opts.opencv_lib_path = 'C:\opencv240\install\lib'; % OpenCV lib path 11 | +opts.opencv_include_path = '/usr/include'; % OpenCV include path 12 | +opts.opencv_lib_path = '/usr/lib/x86_64-linux-gnu'; % OpenCV lib path 13 | opts.clean = false; % clean mode 14 | opts.dryrun = false; % dry run mode 15 | opts.verbose = 1; % output verbosity 16 | @@ -23,7 +23,7 @@ 17 | if opts.verbose > 0, disp(cmd); end 18 | if ~opts.dryrun, delete(cmd); end 19 | 20 | - cmd = fullfile('*.obj'); 21 | + cmd = fullfile('*.o'); 22 | if opts.verbose > 0, disp(cmd); end 23 | if ~opts.dryrun, delete(cmd); end 24 | 25 | @@ -49,7 +49,7 @@ 26 | 27 | % Compile the mex file 28 | src = 'mexBMS.cpp'; 29 | -obj = 'BMS.obj MxArray.obj'; 30 | +obj = 'BMS.o MxArray.o'; 31 | cmd = sprintf('mex %s %s %s', mex_flags, src, obj); 32 | if opts.verbose > 0, disp(cmd); end 33 | if ~opts.dryrun, eval(cmd); end 34 | @@ -83,7 +83,7 @@ 35 | 36 | function l = lib_names(L_path) 37 | %LIB_NAMES return library names 38 | - d = dir( fullfile(L_path,'opencv_*.lib') ); 39 | - l = regexp({d.name}, '(opencv_core.+)\.lib|(opencv_imgproc.+)\.lib|(opencv_highgui.+)\.lib', 'tokens', 'once'); 40 | + d = dir( fullfile(L_path,'libopencv_*.so') ); 41 | + l = regexp({d.name}, 'lib(opencv_core.*)\.so|lib(opencv_imgproc.*)\.so|lib(opencv_highgui.*)\.so', 'tokens', 'once'); 42 | l = [l{:}]; 43 | -end 44 | \ No newline at end of file 45 | +end 46 | -------------------------------------------------------------------------------- /pysaliency/external_models/scripts/BMS/patches/correct_add_path.diff: -------------------------------------------------------------------------------- 1 | Index: src/BMS.m 2 | =================================================================== 3 | --- src.orig/BMS.m 2013-09-14 12:28:06.000000000 +0200 4 | +++ src/BMS.m 2014-07-04 16:44:51.765801441 +0200 5 | @@ -33,7 +33,8 @@ 6 | % **sod** is a boolean value indicating whether to use the salient object 7 | % detection mode 8 | 9 | -addpath('mex/'); 10 | +[directory name ext] = fileparts(mfilename('fullpath')); 11 | +addpath(fullfile(directory, 'mex')); 12 | 13 | if input_dir(end) ~= '/' && input_dir(end) ~= '\' 14 | input_dir = [input_dir,'/']; 15 | -------------------------------------------------------------------------------- /pysaliency/external_models/scripts/BMS/patches/fix_FileGettor.diff: -------------------------------------------------------------------------------- 1 | Index: src/mex/fileGettor.h 2 | =================================================================== 3 | --- src.orig/mex/fileGettor.h 2013-09-13 19:44:12.000000000 +0200 4 | +++ src/mex/fileGettor.h 2014-12-05 20:34:55.405538000 +0100 5 | @@ -35,8 +35,13 @@ 6 | cout << "Error opening " << directory << endl; 7 | } 8 | 9 | - readdir(dp);//. 10 | - readdir(dp);//.. 11 | + //commented out by Matthias Kuemmerer: 12 | + //readdir does not garantue any order, in 13 | + //particular . and .. are not garanteed to 14 | + //be the first results. 15 | + // 16 | + //readdir(dp);//. 17 | + //readdir(dp);//.. 
18 | while ((dirp = readdir(dp)) != NULL) { 19 | string filename(dirp->d_name); 20 | _name_list.push_back(filename); 21 | -------------------------------------------------------------------------------- /pysaliency/external_models/scripts/BMS/patches/series: -------------------------------------------------------------------------------- 1 | adapt_opencv_paths.diff 2 | correct_add_path.diff 3 | fix_FileGettor.diff 4 | -------------------------------------------------------------------------------- /pysaliency/external_models/scripts/ContextAwareSaliency_wrapper.m: -------------------------------------------------------------------------------- 1 | function [ ] = ContextAwareSaliency(filename, outname) 2 | 3 | addpath('source') 4 | 5 | file_names{1} = filename; 6 | img = imread(filename); 7 | [nrows ncols cc] = size(img); 8 | MOV = saliency(file_names); 9 | saliency_map = MOV{1}.SaliencyMap; 10 | saliency_map = imresize(saliency_map, [nrows, ncols]); 11 | save(outname, 'saliency_map'); 12 | 13 | -------------------------------------------------------------------------------- /pysaliency/external_models/scripts/CovSal_wrapper.m: -------------------------------------------------------------------------------- 1 | function [ ] = CovSal_wrapper(filename, outname, size, quantile, centerbias, modeltype) 2 | addpath('saliency') 3 | 4 | % options for saliency estimation 5 | options.size = size; 6 | options.quantile = quantile; 7 | options.centerBias = centerbias; 8 | options.modeltype = modeltype; 9 | saliency_map = saliencymap(filename, options); 10 | save(outname, 'saliency_map'); 11 | -------------------------------------------------------------------------------- /pysaliency/external_models/scripts/GBVS/GBVSIttiKoch_wrapper.m: -------------------------------------------------------------------------------- 1 | function [ ] = GBVSIttiKoch_wrapper(filename, outname) 2 | 3 | addpath('gbvs') 4 | addpath('gbvs/algsrc') 5 | addpath('gbvs/compile') 6 | addpath('gbvs/initcache') 7 | addpath('gbvs/saltoolbox') 8 | addpath(genpath('gbvs/util')) 9 | 10 | img = imread(filename); 11 | map = ittikochmap(img); 12 | saliency_map = map.master_map_resized; 13 | save(outname, 'saliency_map'); 14 | -------------------------------------------------------------------------------- /pysaliency/external_models/scripts/GBVS/GBVS_wrapper.m: -------------------------------------------------------------------------------- 1 | function [ ] = GBVS_wrapper(filename, outname) 2 | 3 | addpath('gbvs') 4 | addpath('gbvs/algsrc') 5 | addpath('gbvs/compile') 6 | addpath('gbvs/initcache') 7 | addpath('gbvs/saltoolbox') 8 | addpath(genpath('gbvs/util')) 9 | 10 | img = imread(filename); 11 | map = gbvs(img); 12 | saliency_map = map.master_map_resized; 13 | save(outname, 'saliency_map'); 14 | 15 | -------------------------------------------------------------------------------- /pysaliency/external_models/scripts/GBVS/patches/get_path: -------------------------------------------------------------------------------- 1 | Index: src/algsrc/initGBVS.m 2 | =================================================================== 3 | --- src.orig/algsrc/initGBVS.m 2013-08-19 15:25:41.513971846 +0200 4 | +++ src/algsrc/initGBVS.m 2013-08-19 15:25:41.505971965 +0200 5 | @@ -27,7 +27,8 @@ 6 | 7 | % weight matrix 8 | if ( ~param.useIttiKochInsteadOfGBVS ) 9 | - load mypath; 10 | + [directory name ext] = fileparts(mfilename('fullpath')); 11 | + [pathroot name ext] = fileparts(directory); 12 | ufile =
sprintf('%s__m%s__%s.mat',num2str(salmapsize),num2str(param.multilevels),num2str(param.cyclic_type)); 13 | ufile(ufile==' ') = '_'; 14 | ufile = fullfile( pathroot , 'initcache' , ufile ); 15 | Index: src/util/getFeatureMaps.m 16 | =================================================================== 17 | --- src.orig/util/getFeatureMaps.m 2013-08-19 15:25:41.513971846 +0200 18 | +++ src/util/getFeatureMaps.m 2013-08-19 15:25:41.509971905 +0200 19 | @@ -4,7 +4,8 @@ 20 | % this computes feature maps for each cannnel in featureChannels/ 21 | % 22 | 23 | -load mypath; 24 | +[directory name ext] = fileparts(mfilename('fullpath')); 25 | +[pathroot name ext] = fileparts(directory); 26 | 27 | %%%% 28 | %%%% STEP 1 : form image pyramid and prune levels if pyramid levels get too small. 29 | -------------------------------------------------------------------------------- /pysaliency/external_models/scripts/GBVS/patches/make_mex_files_octave_compatible: -------------------------------------------------------------------------------- 1 | Index: src/algsrc/mexArrangeLinear.cc 2 | =================================================================== 3 | --- src.orig/algsrc/mexArrangeLinear.cc 2010-02-19 23:54:34.000000000 +0100 4 | +++ src/algsrc/mexArrangeLinear.cc 2015-01-05 17:48:52.028027000 +0100 5 | @@ -2,7 +2,6 @@ 6 | #include 7 | #include 8 | #include 9 | -#include 10 | #include 11 | 12 | // Avalues = mexArrangeLinear( A , dims ) 13 | Index: src/algsrc/mexAssignWeights.cc 14 | =================================================================== 15 | --- src.orig/algsrc/mexAssignWeights.cc 2010-02-19 23:54:34.000000000 +0100 16 | +++ src/algsrc/mexAssignWeights.cc 2015-01-05 17:48:55.823430000 +0100 17 | @@ -2,7 +2,6 @@ 18 | #include 19 | #include 20 | #include 21 | -#include 22 | #include 23 | 24 | // mexAssignWeights( AL , D , MM , algtype ) 25 | Index: src/algsrc/mexColumnNormalize.cc 26 | =================================================================== 27 | --- src.orig/algsrc/mexColumnNormalize.cc 2010-02-19 23:54:34.000000000 +0100 28 | +++ src/algsrc/mexColumnNormalize.cc 2015-01-05 17:48:59.834418000 +0100 29 | @@ -2,7 +2,6 @@ 30 | #include 31 | #include 32 | #include 33 | -#include 34 | #include 35 | 36 | // Normalizes so that each column sums to one 37 | Index: src/algsrc/mexSumOverScales.cc 38 | =================================================================== 39 | --- src.orig/algsrc/mexSumOverScales.cc 2010-02-19 23:54:34.000000000 +0100 40 | +++ src/algsrc/mexSumOverScales.cc 2015-01-05 17:49:04.403376000 +0100 41 | @@ -3,7 +3,6 @@ 42 | #include 43 | #include 44 | #include 45 | -#include 46 | #include 47 | 48 | // Vo = mexSumOverScales( v , lx , N ) 49 | Index: src/algsrc/mexVectorToMap.cc 50 | =================================================================== 51 | --- src.orig/algsrc/mexVectorToMap.cc 2010-02-19 23:54:36.000000000 +0100 52 | +++ src/algsrc/mexVectorToMap.cc 2015-01-05 17:49:15.931328000 +0100 53 | @@ -3,7 +3,6 @@ 54 | #include 55 | #include 56 | #include 57 | -#include 58 | #include 59 | 60 | // outmap = mexVectorToMap( v , dim ) 61 | Index: src/saltoolbox/mexLocalMaximaGBVS.cc 62 | =================================================================== 63 | --- src.orig/saltoolbox/mexLocalMaximaGBVS.cc 2010-02-19 23:54:52.000000000 +0100 64 | +++ src/saltoolbox/mexLocalMaximaGBVS.cc 2015-01-05 17:49:31.381773000 +0100 65 | @@ -2,7 +2,6 @@ 66 | #include 67 | #include 68 | #include 69 | -#include 70 | #include 71 | 72 | double getVal(double* img, int x, int y, int w, 
int h); 73 | Index: src/saltoolbox/mySubsample.cc 74 | =================================================================== 75 | --- src.orig/saltoolbox/mySubsample.cc 2010-02-19 23:54:54.000000000 +0100 76 | +++ src/saltoolbox/mySubsample.cc 2015-01-05 17:49:35.284199000 +0100 77 | @@ -2,7 +2,6 @@ 78 | #include 79 | #include 80 | #include 81 | -#include 82 | #include 83 | 84 | void lowPass6yDecY(float* sptr, float* rptr, int w, int hs); 85 | Index: src/util/myContrast.cc 86 | =================================================================== 87 | --- src.orig/util/myContrast.cc 2010-02-19 23:54:56.000000000 +0100 88 | +++ src/util/myContrast.cc 2015-01-05 17:49:44.572766000 +0100 89 | @@ -2,7 +2,6 @@ 90 | #include 91 | #include 92 | #include 93 | -#include 94 | #include 95 | 96 | void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) { 97 | -------------------------------------------------------------------------------- /pysaliency/external_models/scripts/GBVS/patches/series: -------------------------------------------------------------------------------- 1 | get_path 2 | make_mex_files_octave_compatible 3 | -------------------------------------------------------------------------------- /pysaliency/external_models/scripts/IttiKoch_wrapper.m: -------------------------------------------------------------------------------- 1 | function [ ] = IttiKoch(filename, outname) 2 | 3 | addpath('SaliencyToolbox') 4 | 5 | img = initializeImage(filename); 6 | params = defaultSaliencyParams; 7 | salmap = makeSaliencyMap(img, params); 8 | saliency_map = imresize(salmap.data,img.size(1:2)); 9 | save(outname, 'saliency_map'); 10 | 11 | -------------------------------------------------------------------------------- /pysaliency/external_models/scripts/Judd/FaceDetect_patches/change_opencv_include: -------------------------------------------------------------------------------- 1 | Index: FaceDetect/src/FaceDetect.cpp 2 | =================================================================== 3 | --- FaceDetect.orig/src/FaceDetect.cpp 2008-05-15 12:09:24.000000000 +0200 4 | +++ FaceDetect/src/FaceDetect.cpp 2013-08-19 15:00:33.552389407 +0200 5 | @@ -23,7 +23,7 @@ 6 | #include "mex.h" // Required for the use of MEX files 7 | 8 | // Required for OpenCV 9 | -#include "cv.h" 10 | +#include "opencv.hpp" 11 | 12 | static CvMemStorage* storage = 0; 13 | static CvHaarClassifierCascade* cascade = 0; 14 | -------------------------------------------------------------------------------- /pysaliency/external_models/scripts/Judd/FaceDetect_patches/series: -------------------------------------------------------------------------------- 1 | change_opencv_include 2 | -------------------------------------------------------------------------------- /pysaliency/external_models/scripts/Judd/JuddSaliencyModel_patches/find_cascade_file: -------------------------------------------------------------------------------- 1 | Index: JuddSaliencyModel/findObjectFeatures.m 2 | =================================================================== 3 | --- JuddSaliencyModel.orig/findObjectFeatures.m 2013-08-13 17:43:04.087668365 +0200 4 | +++ JuddSaliencyModel/findObjectFeatures.m 2013-08-13 17:50:50.604698538 +0200 5 | @@ -71,7 +71,7 @@ 6 | Img = double(rgb2gray(img)); 7 | 8 | % Run face detector 9 | -cascade='haarcascade_frontalface_alt2.xml'; % a little noisy a few misses 10 | +cascade=fullfile(directory,'haarcascade_frontalface_alt2.xml'); % a little noisy a few misses 11 | FaceData = FaceDetect(cascade,Img); 12 | 13 | 
if find(FaceData==0) 14 | -------------------------------------------------------------------------------- /pysaliency/external_models/scripts/Judd/JuddSaliencyModel_patches/locate_FelzenszwalbDetector_files: -------------------------------------------------------------------------------- 1 | Index: JuddSaliencyModel/findObjectFeatures.m 2 | =================================================================== 3 | --- JuddSaliencyModel.orig/findObjectFeatures.m 2010-02-26 21:12:28.000000000 +0100 4 | +++ JuddSaliencyModel/findObjectFeatures.m 2013-08-13 17:51:27.536146740 +0200 5 | @@ -14,6 +14,10 @@ 6 | % Contact: Tilke Judd at 7 | % ---------------------------------------------------------------------- 8 | 9 | +[directory name ext] = fileparts(mfilename('fullpath')); 10 | +carfilename = fullfile(directory, 'FelzenszwalbDetectors', 'car_final.mat'); 11 | +personfilename = fullfile(directory, 'FelzenszwalbDetectors', 'person_final.mat'); 12 | + 13 | [w h c] = size(img); 14 | Cars=zeros(w, h); 15 | People=zeros(w, h); 16 | @@ -23,7 +27,7 @@ 17 | % Find Cars 18 | %-----------------% 19 | fprintf('Finding cars...'); tic; 20 | -load FelzenszwalbDetectors/car_final.mat; % loads a car model 21 | +load(carfilename); % loads a car model 22 | boxes = detect(img, model, 0); 23 | top = nms(boxes, 0.4); 24 | 25 | @@ -44,7 +48,7 @@ 26 | % Find People 27 | %-----------------% 28 | fprintf('Finding people...'); tic; 29 | -load FelzenszwalbDetectors/person_final.mat; % loads a person model 30 | +load(personfilename); % loads a person model 31 | boxes = detect(img, model, 0); 32 | top = nms(boxes, 0.4); 33 | 34 | -------------------------------------------------------------------------------- /pysaliency/external_models/scripts/Judd/JuddSaliencyModel_patches/series: -------------------------------------------------------------------------------- 1 | locate_FelzenszwalbDetector_files 2 | find_cascade_file 3 | -------------------------------------------------------------------------------- /pysaliency/external_models/scripts/Judd/Judd_wrapper.m: -------------------------------------------------------------------------------- 1 | function [ ] = Judd(filename, outname) 2 | 3 | addpath('source/FaceDetect') 4 | addpath('source/FaceDetect/src') 5 | addpath('source/LabelMeToolbox/features') 6 | addpath('source/LabelMeToolbox/imagemanipulation') 7 | addpath('source/matlabPyrTools') 8 | addpath('source/SaliencyToolbox') 9 | addpath('source/voc-release3.1') 10 | addpath('source/JuddSaliencyModel') 11 | 12 | saliency_map = saliency(filename); 13 | save(outname, 'saliency_map'); 14 | 15 | -------------------------------------------------------------------------------- /pysaliency/external_models/scripts/Judd/SaliencyToolbox_patches/enable_unit16: -------------------------------------------------------------------------------- 1 | Index: SaliencyToolbox/centerSurround.m 2 | =================================================================== 3 | --- SaliencyToolbox.orig/centerSurround.m 2013-07-03 17:30:40.000000000 +0200 4 | +++ SaliencyToolbox/centerSurround.m 2013-08-13 16:50:58.958392479 +0200 5 | @@ -33,7 +33,7 @@ 6 | switch class(params.exclusionMask) 7 | case 'struct' 8 | exclusionIdx = (imresize(params.exclusionMask.data,siz,'nearest') ~= 0); 9 | - case {'double','uint8'} 10 | + case {'double','uint8','uint16'} 11 | exclusionIdx = (imresize(params.exclusionMask,siz,'nearest') ~= 0); 12 | case 'logical' 13 | exclusionIdx = imresize(params.exclusionMask,siz,'nearest'); 14 | Index: SaliencyToolbox/guiSaliency.m 15 | 
=================================================================== 16 | --- SaliencyToolbox.orig/guiSaliency.m 2013-07-03 17:30:40.000000000 +0200 17 | +++ SaliencyToolbox/guiSaliency.m 2013-08-13 16:51:20.074076565 +0200 18 | @@ -67,7 +67,7 @@ 19 | newImg = varargin{1}; 20 | err = ''; 21 | state = 'ImageLoaded'; 22 | - case {'char','uint8','double'} 23 | + case {'char','uint8','uint16','double'} 24 | [newImg,err] = initializeImage(varargin{1}); 25 | otherwise 26 | err = 1; 27 | Index: SaliencyToolbox/initializeImage.m 28 | =================================================================== 29 | --- SaliencyToolbox.orig/initializeImage.m 2013-07-03 17:30:40.000000000 +0200 30 | +++ SaliencyToolbox/initializeImage.m 2013-08-13 16:56:03.041842803 +0200 31 | @@ -47,7 +47,7 @@ 32 | Img.filename = varargin{1}; 33 | Img.data = NaN; 34 | Img.type = 'unknown'; 35 | - case {'uint8','double'} 36 | + case {'uint8','uint16','double'} 37 | Img.filename = NaN; 38 | Img.data = varargin{1}; 39 | Img.type = 'unknown'; 40 | @@ -62,14 +62,14 @@ 41 | case 'char' 42 | Img.data = NaN; 43 | Img.type = varargin{2}; 44 | - case {'uint8','double'} 45 | + case {'uint8','uint16','double'} 46 | Img.data = varargin{2}; 47 | Img.type = 'unknown'; 48 | otherwise 49 | error('Don''t know how to handle image data of class %s.',class(varargin{2})); 50 | end 51 | 52 | - case {'uint8','double'} 53 | + case {'uint8','uint16','double'} 54 | Img.filename = NaN; 55 | Img.data = varargin{1}; 56 | Img.type = varargin{2}; 57 | Index: SaliencyToolbox/loadImage.m 58 | =================================================================== 59 | --- SaliencyToolbox.orig/loadImage.m 2013-07-03 17:30:40.000000000 +0200 60 | +++ SaliencyToolbox/loadImage.m 2013-08-13 16:52:33.556977143 +0200 61 | @@ -19,6 +19,8 @@ 62 | 63 | if isa(Image.data,'uint8') 64 | imgData = im2double(Image.data); 65 | +elseif isa(Image.data,'uint16') 66 | + imgData = im2double(Image.data); 67 | elseif isnan(Image.data) 68 | imgData = im2double(imread(Image.filename)); 69 | else 70 | -------------------------------------------------------------------------------- /pysaliency/external_models/scripts/Judd/SaliencyToolbox_patches/series: -------------------------------------------------------------------------------- 1 | enable_unit16 2 | -------------------------------------------------------------------------------- /pysaliency/external_models/scripts/Judd/voc_patches/change_fconv: -------------------------------------------------------------------------------- 1 | Index: voc-release3.1/compile.m 2 | =================================================================== 3 | --- voc-release3.1.orig/compile.m 2009-06-09 04:26:04.000000000 +0200 4 | +++ voc-release3.1/compile.m 2013-08-13 15:37:36.404434669 +0200 5 | @@ -6,8 +6,8 @@ 6 | % 1 is fastest, 3 is slowest 7 | 8 | % 1) multithreaded convolution using blas 9 | -mex -O fconvblas.cc -lmwblas -o fconv 10 | +% mex -O fconvblas.cc -lmwblas -o fconv 11 | % 2) mulththreaded convolution without blas 12 | % mex -O fconvMT.cc -o fconv 13 | % 3) basic convolution, very compatible 14 | -% mex -O fconv.cc -o fconv 15 | +mex -O fconv.cc -o fconv 16 | -------------------------------------------------------------------------------- /pysaliency/external_models/scripts/Judd/voc_patches/matlabR2014a_compatible: -------------------------------------------------------------------------------- 1 | Index: voc-release3.1/compile.m 2 | =================================================================== 3 | --- voc-release3.1.orig/compile.m 
2015-01-07 22:06:26.797893000 +0100 4 | +++ voc-release3.1/compile.m 2015-01-07 22:07:09.316640000 +0100 5 | @@ -6,8 +6,8 @@ 6 | % 1 is fastest, 3 is slowest 7 | 8 | % 1) multithreaded convolution using blas 9 | -% mex -O fconvblas.cc -lmwblas -o fconv 10 | +% mex -O fconvblas.cc -lmwblas -output fconv 11 | % 2) mulththreaded convolution without blas 12 | -% mex -O fconvMT.cc -o fconv 13 | +% mex -O fconvMT.cc -output fconv 14 | % 3) basic convolution, very compatible 15 | -mex -O fconv.cc -o fconv 16 | +mex -O fconv.cc -output fconv 17 | -------------------------------------------------------------------------------- /pysaliency/external_models/scripts/Judd/voc_patches/matlabR2021a_compatible: -------------------------------------------------------------------------------- 1 | Index: voc-release3.1/resize.cc 2 | =================================================================== 3 | --- voc-release3.1.orig/resize.cc 2009-05-19 16:13:23.000000000 +0200 4 | +++ voc-release3.1/resize.cc 2023-06-13 23:11:21.000000000 +0200 5 | @@ -82,7 +82,7 @@ 6 | // returns resized image 7 | mxArray *resize(const mxArray *mxsrc, const mxArray *mxscale) { 8 | double *src = (double *)mxGetPr(mxsrc); 9 | - const int *sdims = mxGetDimensions(mxsrc); 10 | + const mwSize *sdims = mxGetDimensions(mxsrc); 11 | if (mxGetNumberOfDimensions(mxsrc) != 3 || 12 | mxGetClassID(mxsrc) != mxDOUBLE_CLASS) 13 | mexErrMsgTxt("Invalid input"); 14 | @@ -91,7 +91,7 @@ 15 | if (scale > 1) 16 | mexErrMsgTxt("Invalid scaling factor"); 17 | 18 | - int ddims[3]; 19 | + mwSize ddims[3]; 20 | ddims[0] = (int)round(sdims[0]*scale); 21 | ddims[1] = (int)round(sdims[1]*scale); 22 | ddims[2] = sdims[2]; 23 | Index: voc-release3.1/dt.cc 24 | =================================================================== 25 | --- voc-release3.1.orig/dt.cc 2009-05-19 16:13:23.000000000 +0200 26 | +++ voc-release3.1/dt.cc 2023-06-13 23:16:11.000000000 +0200 27 | @@ -47,7 +47,7 @@ 28 | if (mxGetClassID(prhs[0]) != mxDOUBLE_CLASS) 29 | mexErrMsgTxt("Invalid input"); 30 | 31 | - const int *dims = mxGetDimensions(prhs[0]); 32 | + const mwSize *dims = mxGetDimensions(prhs[0]); 33 | double *vals = (double *)mxGetPr(prhs[0]); 34 | double ax = mxGetScalar(prhs[1]); 35 | double bx = mxGetScalar(prhs[2]); 36 | Index: voc-release3.1/features.cc 37 | =================================================================== 38 | --- voc-release3.1.orig/features.cc 2009-05-19 16:13:23.000000000 +0200 39 | +++ voc-release3.1/features.cc 2023-06-13 23:18:18.000000000 +0200 40 | @@ -35,7 +35,7 @@ 41 | // returns HOG features 42 | mxArray *process(const mxArray *mximage, const mxArray *mxsbin) { 43 | double *im = (double *)mxGetPr(mximage); 44 | - const int *dims = mxGetDimensions(mximage); 45 | + const mwSize *dims = mxGetDimensions(mximage); 46 | if (mxGetNumberOfDimensions(mximage) != 3 || 47 | dims[2] != 3 || 48 | mxGetClassID(mximage) != mxDOUBLE_CLASS) 49 | @@ -51,7 +51,7 @@ 50 | double *norm = (double *)mxCalloc(blocks[0]*blocks[1], sizeof(double)); 51 | 52 | // memory for HOG features 53 | - int out[3]; 54 | + mwSize out[3]; 55 | out[0] = max(blocks[0]-2, 0); 56 | out[1] = max(blocks[1]-2, 0); 57 | out[2] = 27+4; 58 | -------------------------------------------------------------------------------- /pysaliency/external_models/scripts/Judd/voc_patches/series: -------------------------------------------------------------------------------- 1 | change_fconv 2 | matlabR2014a_compatible 3 | matlabR2021a_compatible 4 | 
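The `series` files above list quilt patches in their application order. As a minimal sketch (the paths here are hypothetical), such a series can be applied programmatically via pysaliency's quilt support, cf. `apply_quilt` in pysaliency/external_models/utils.py below:

    from pysaliency.quilt import QuiltSeries
    series = QuiltSeries('/path/to/voc_patches')  # directory containing the 'series' file
    series.apply('/path/to/voc-release3.1', verbose=True)  # applies each listed patch in order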
-------------------------------------------------------------------------------- /pysaliency/external_models/scripts/RARE2012_wrapper.m: -------------------------------------------------------------------------------- 1 | function [ ] = RARE2012(filename, outname) 2 | 3 | addpath('source/VisualAttention-Rare2012-55ba7414b971429e5e899ddfa574e4235fc806e6') 4 | 5 | I = im2double(imread(filename)); 6 | saliency_map = rare2012(I); 7 | save(outname, 'saliency_map'); 8 | 9 | -------------------------------------------------------------------------------- /pysaliency/external_models/scripts/SUN_wrapper.m: -------------------------------------------------------------------------------- 1 | % created by Aykut Erdem 2 | % adapted by Matthias Kuemmerer for pysaliency 3 | 4 | function [ ] = SUN(filename, outname, scale) 5 | 6 | addpath('saliency') 7 | 8 | img = imread(filename); 9 | salmap = saliencyimage(img,scale); 10 | salmap = imresize(salmap,1/scale, 'nearest'); 11 | height = size(salmap,1); 12 | width = size(salmap,2); 13 | ydiff = size(img,1)-size(salmap,1); 14 | xdiff = size(img,2)-size(salmap,2); 15 | ydiff = round(ydiff/ 2); 16 | xdiff = round(xdiff/ 2); 17 | saliency_map = ones(size(img,1),size(img,2))*min(salmap(:)); 18 | saliency_map(ydiff+1:ydiff+height,xdiff+1:xdiff+width) = salmap; 19 | save(outname, 'saliency_map'); 20 | -------------------------------------------------------------------------------- /pysaliency/external_models/scripts/ensure_image_is_color_image.m: -------------------------------------------------------------------------------- 1 | function [new_img] = ensure_image_is_color_image(img) 2 | if length(size(img)) == 2 3 | new_img = ones(size(img,1), size(img, 2), 3, class(img)); 4 | new_img(:,:,1) = img; 5 | new_img(:,:,2) = img; 6 | new_img(:,:,3) = img; 7 | else 8 | new_img = img; 9 | end 10 | end 11 | -------------------------------------------------------------------------------- /pysaliency/external_models/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | import os 4 | import tarfile 5 | import tempfile 6 | import zipfile 7 | from tempfile import TemporaryDirectory 8 | 9 | from pkg_resources import resource_listdir, resource_string 10 | 11 | from ..quilt import QuiltSeries 12 | from ..utils import download_and_check 13 | 14 | 15 | def write_file(filename, contents): 16 | """Write contents to a file and close the file safely""" 17 | with open(filename, 'wb') as f: 18 | f.write(contents) 19 | 20 | 21 | def extract_zipfile(filename, extract_to): 22 | if zipfile.is_zipfile(filename): 23 | z = zipfile.ZipFile(filename) 24 | #os.makedirs(extract_to) 25 | z.extractall(extract_to) 26 | elif tarfile.is_tarfile(filename): 27 | t = tarfile.open(filename) 28 | t.extractall(extract_to) 29 | else: 30 | raise ValueError('Unknown archive type', filename) 31 | 32 | 33 | def unpack_directory(package, resource_name, location): 34 | files = resource_listdir(package, resource_name) 35 | for file in files: 36 | write_file(os.path.join(location, file), 37 | resource_string(package, os.path.join(resource_name, file))) 38 | 39 | 40 | def apply_quilt(source_location, package, resource_name, patch_directory, verbose=True): 41 | """Apply quilt series from package data to source code""" 42 | os.makedirs(patch_directory) 43 | unpack_directory(package, resource_name, patch_directory) 44 | series = QuiltSeries(patch_directory) 45 | series.apply(source_location,
verbose=verbose) 46 | 47 | 48 | def download_extract_patch(url, hash, location, location_in_archive=True, patches=None, verify_ssl=True): 49 | """Download, extract and optionally patch code""" 50 | with TemporaryDirectory() as temp_dir: 51 | if not os.path.isdir(temp_dir): 52 | os.makedirs(temp_dir) 53 | archive_name = os.path.basename(url) 54 | download_and_check(url, 55 | os.path.join(temp_dir, archive_name), 56 | hash, 57 | verify_ssl=verify_ssl) 58 | 59 | if location_in_archive: 60 | target = os.path.dirname(os.path.normpath(location)) 61 | else: 62 | target = location 63 | extract_zipfile(os.path.join(temp_dir, archive_name), 64 | target) 65 | 66 | if patches: 67 | parent_directory = os.path.dirname(os.path.normpath(location)) 68 | patch_directory = os.path.join(parent_directory, os.path.basename(patches)) 69 | apply_quilt(location, __name__, os.path.join('scripts', patches), patch_directory) 70 | 71 | 72 | class ExternalModelMixin(object): 73 | """ 74 | Download and cache necessary files. 75 | 76 | If the location is None, a temporary directory will be used. 77 | If the location is not None, the data will be stored in a 78 | subdirectory of location named after `__modelname__`. If this 79 | subdirectory already exists, the initialization will 80 | not be run. 81 | 82 | After running `setup()`, the actual location will be 83 | stored in `self.location`. 84 | 85 | To make use of this Mixin, override `_setup()` 86 | and run `setup(location)`. 87 | """ 88 | def setup(self, location, *args, **kwargs): 89 | if location is None: 90 | self.location = tempfile.mkdtemp() 91 | self._setup(*args, **kwargs) 92 | else: 93 | self.location = os.path.join(location, self.__modelname__) 94 | if not os.path.exists(self.location): 95 | self._setup(*args, **kwargs) 96 | 97 | def _setup(self, *args, **kwargs): 98 | raise NotImplementedError() -------------------------------------------------------------------------------- /pysaliency/http_models.py: -------------------------------------------------------------------------------- 1 | from .models import ScanpathModel 2 | from PIL import Image 3 | from io import BytesIO 4 | import requests 5 | import json 6 | import numpy as np 7 | import orjson 8 | 9 | from .datasets import as_stimulus 10 | 11 | class HTTPScanpathModel(ScanpathModel): 12 | """ 13 | A scanpath model that uses an HTTP server to make predictions.
14 | 15 | The model is provided with a URL where it expects a server with the following API: 16 | 17 | /conditional_log_density: expects a POST request with a file attachment `stimulus` 18 | containing the stimulus and a JSON body containing x_hist, y_hist, t_hist and a dictionary with other attributes 19 | /type: returns the model type and version 20 | """ 21 | def __init__(self, url): 22 | self.url = url 23 | self.check_type() 24 | 25 | @property 26 | def log_density_url(self): 27 | return self.url + "/conditional_log_density" 28 | 29 | @property 30 | def type_url(self): 31 | return self.url + "/type" 32 | 33 | def conditional_log_density(self, stimulus, x_hist, y_hist, t_hist, attributes=None, out=None): 34 | # build request 35 | stimulus_object = as_stimulus(stimulus) 36 | 37 | # TODO: check for file stimuli, in this case use original file to save encoding time 38 | pil_image = Image.fromarray(stimulus_object.stimulus_data) 39 | image_bytes = BytesIO() 40 | pil_image.save(image_bytes, format='png') 41 | 42 | def _convert_attribute(attribute): 43 | if isinstance(attribute, np.ndarray): 44 | return attribute.tolist() 45 | if isinstance(attribute, (np.int64, np.int32)): 46 | return int(attribute) 47 | if isinstance(attribute, (np.float64, np.float32)): 48 | return float(attribute) 49 | return attribute 50 | 51 | json_data = { 52 | "x_hist": x_hist.tolist(), 53 | "y_hist": y_hist.tolist(), 54 | "t_hist": t_hist.tolist(), 55 | "attributes": {key: _convert_attribute(value) for key, value in (attributes or {}).items()} 56 | } 57 | # send request 58 | response = requests.post(f"{self.log_density_url}", data={'json_data': orjson.dumps(json_data)}, files={'stimulus': image_bytes.getvalue()}) 59 | 60 | # parse response 61 | if response.status_code != 200: 62 | raise ValueError(f"Server returned status code {response.status_code}") 63 | 64 | json_data = orjson.loads(response.text) 65 | prediction = np.array(json_data['log_density']) 66 | return prediction 67 | 68 | def check_type(self): 69 | response = requests.get(f"{self.type_url}").json() 70 | if response['type'] != 'ScanpathModel': 71 | raise ValueError(f"invalid Model type: {response['type']}. Expected 'ScanpathModel'") 72 | if response['version'] not in ['v1.0.0']: 73 | raise ValueError(f"invalid Model version: {response['version']}.
Expected 'v1.0.0'") 74 | -------------------------------------------------------------------------------- /pysaliency/metric_optimization.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division, absolute_import, unicode_literals 2 | 3 | from .saliency_map_models import SaliencyMapModel 4 | 5 | 6 | class SIMSaliencyMapModel(SaliencyMapModel): 7 | def __init__(self, parent_model, 8 | kernel_size, 9 | train_samples_per_epoch=1000, val_samples=1000, 10 | train_seed=43, val_seed=42, 11 | fixation_count=100, batch_size=50, 12 | max_batch_size=None, 13 | initial_learning_rate=1e-7, 14 | backlook=1, 15 | min_iter=0, 16 | max_iter=1000, 17 | truncate_gaussian=3, 18 | learning_rate_decay_samples=None, 19 | learning_rate_decay_scheme=None, 20 | learning_rate_decay_ratio=0.333333333, 21 | minimum_learning_rate=1e-11, 22 | initial_model=None, 23 | verbose=True, 24 | session_config=None, 25 | library='torch', 26 | **kwargs 27 | ): 28 | super(SIMSaliencyMapModel, self).__init__(**kwargs) 29 | self.parent_model = parent_model 30 | 31 | self.kernel_size = kernel_size 32 | self.train_samples_per_epoch = train_samples_per_epoch 33 | self.val_samples = val_samples 34 | self.train_seed = train_seed 35 | self.val_seed = val_seed 36 | self.fixation_count = fixation_count 37 | self.batch_size = batch_size 38 | self.max_batch_size = max_batch_size 39 | self.initial_learning_rate = initial_learning_rate 40 | self.backlook = backlook 41 | self.min_iter = min_iter 42 | self.max_iter = max_iter 43 | self.truncate_gaussian = truncate_gaussian 44 | self.learning_rate_decay_samples = learning_rate_decay_samples 45 | self.learning_rate_decay_scheme = learning_rate_decay_scheme 46 | self.learning_rate_decay_ratio = learning_rate_decay_ratio 47 | self.minimum_learning_rate = minimum_learning_rate 48 | self.initial_model = initial_model 49 | self.verbose = verbose 50 | self.session_config = session_config 51 | self.library = library 52 | 53 | def _saliency_map(self, stimulus): 54 | log_density = self.parent_model.log_density(stimulus) 55 | 56 | if self.initial_model: 57 | initial_saliency_map = self.initial_model.saliency_map(stimulus) 58 | else: 59 | initial_saliency_map = None 60 | 61 | if self.library.lower() == 'tensorflow': 62 | from .metric_optimization_tf import maximize_expected_sim 63 | elif self.library.lower() == 'torch': 64 | from .metric_optimization_torch import maximize_expected_sim 65 | else: 66 | raise ValueError(self.library) 67 | 68 | 69 | 70 | saliency_map, val_scores = maximize_expected_sim( 71 | log_density, 72 | kernel_size=self.kernel_size, 73 | train_samples_per_epoch=self.train_samples_per_epoch, 74 | val_samples=self.val_samples, 75 | train_seed=self.train_seed, 76 | val_seed=self.val_seed, 77 | fixation_count=self.fixation_count, 78 | batch_size=self.batch_size, 79 | max_batch_size=self.max_batch_size, 80 | verbose=self.verbose, 81 | session_config=self.session_config, 82 | initial_learning_rate=self.initial_learning_rate, 83 | backlook=self.backlook, 84 | min_iter=self.min_iter, 85 | max_iter=self.max_iter, 86 | truncate_gaussian=self.truncate_gaussian, 87 | learning_rate_decay_samples=self.learning_rate_decay_samples, 88 | initial_saliency_map=initial_saliency_map, 89 | learning_rate_decay_scheme=self.learning_rate_decay_scheme, 90 | learning_rate_decay_ratio=self.learning_rate_decay_ratio, 91 | minimum_learning_rate=self.minimum_learning_rate 92 | ) 93 | return saliency_map 94 | 
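# Usage sketch (an illustration, not part of the source file; `probabilistic_model`
# and `stimulus` are hypothetical objects): wrapping a probabilistic model to
# obtain saliency maps optimized for the SIM metric.
#
#     from pysaliency.metric_optimization import SIMSaliencyMapModel
#     sim_model = SIMSaliencyMapModel(probabilistic_model, kernel_size=35, library='torch')
#     smap = sim_model.saliency_map(stimulus)  # runs the optimization on first access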
-------------------------------------------------------------------------------- /pysaliency/metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, print_function, division, unicode_literals 2 | 3 | import numpy as np 4 | 5 | 6 | def normalize_saliency_map(saliency_map, cdf, cdf_bins): 7 | """ Normalize a saliency map so that its values are distributed according to a given CDF 8 | """ 9 | 10 | smap = saliency_map.copy() 11 | shape = smap.shape 12 | smap = smap.flatten() 13 | smap = np.argsort(np.argsort(smap)).astype(float) 14 | smap /= 1.0 * len(smap) 15 | 16 | inds = np.searchsorted(cdf, smap, side='right') 17 | smap = cdf_bins[inds] 18 | smap = smap.reshape(shape) 19 | 20 | return smap 21 | 22 | 23 | def convert_saliency_map_to_density(saliency_map, minimum_value=0.0): 24 | if saliency_map.min() < 0: 25 | saliency_map = saliency_map - saliency_map.min() 26 | saliency_map = saliency_map + minimum_value 27 | 28 | saliency_map_sum = saliency_map.sum() 29 | if saliency_map_sum: 30 | saliency_map = saliency_map / saliency_map_sum 31 | else: 32 | saliency_map[:] = 1.0 33 | saliency_map /= saliency_map.sum() 34 | 35 | return saliency_map 36 | 37 | 38 | def NSS(saliency_map, xs, ys): 39 | xs = np.asarray(xs, dtype=int) 40 | ys = np.asarray(ys, dtype=int) 41 | saliency_map = np.asarray(saliency_map, dtype=float) 42 | 43 | mean = saliency_map.mean() 44 | std = saliency_map.std() 45 | 46 | value = saliency_map[ys, xs].copy() 47 | value -= mean 48 | 49 | if std: 50 | value /= std 51 | 52 | return value 53 | 54 | 55 | def CC(saliency_map_1, saliency_map_2): 56 | def normalize(saliency_map): 57 | saliency_map = np.asarray(saliency_map, dtype=float) 58 | saliency_map -= saliency_map.mean() 59 | std = saliency_map.std() 60 | 61 | if std: 62 | saliency_map /= std 63 | 64 | return saliency_map, std == 0 65 | 66 | smap1, constant1 = normalize(saliency_map_1.copy()) 67 | smap2, constant2 = normalize(saliency_map_2.copy()) 68 | 69 | if constant1 and not constant2: 70 | return 0.0 71 | else: 72 | return np.corrcoef(smap1.flatten(), smap2.flatten())[0, 1] 73 | 74 | 75 | def probabilistic_image_based_kl_divergence(logp1, logp2, log_regularization=0, quotient_regularization=0): 76 | if log_regularization or quotient_regularization: 77 | return (np.exp(logp2) * np.log(log_regularization + np.exp(logp2) / (np.exp(logp1) + quotient_regularization))).sum() 78 | else: 79 | return (np.exp(logp2) * (logp2 - logp1)).sum() 80 | 81 | 82 | def image_based_kl_divergence(saliency_map_1, saliency_map_2, minimum_value=1e-20, log_regularization=0, quotient_regularization=0): 83 | """ KLDiv. This function is not symmetric: saliency_map_2 is treated as the empirical saliency map.
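Concretely, both maps are converted to probability densities p1 and p2, and the returned value is sum(p2 * log(p2 / p1)), computed via probabilistic_image_based_kl_divergence above.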
""" 84 | log_density_1 = np.log(convert_saliency_map_to_density(saliency_map_1, minimum_value=minimum_value)) 85 | log_density_2 = np.log(convert_saliency_map_to_density(saliency_map_2, minimum_value=minimum_value)) 86 | 87 | return probabilistic_image_based_kl_divergence(log_density_1, log_density_2, log_regularization=log_regularization, quotient_regularization=quotient_regularization) 88 | 89 | 90 | def MIT_KLDiv(saliency_map_1, saliency_map_2): 91 | """ compute image-based KL divergence with same hyperparameters as in Tuebingen/MIT Saliency Benchmark """ 92 | return image_based_kl_divergence( 93 | saliency_map_1, 94 | saliency_map_2, 95 | minimum_value=0, 96 | log_regularization=2.2204e-16, 97 | quotient_regularization=2.2204e-16 98 | ) 99 | 100 | 101 | def SIM(saliency_map_1, saliency_map_2): 102 | """ Compute similiarity metric. """ 103 | density_1 = convert_saliency_map_to_density(saliency_map_1, minimum_value=0) 104 | density_2 = convert_saliency_map_to_density(saliency_map_2, minimum_value=0) 105 | 106 | return np.min([density_1, density_2], axis=0).sum() 107 | -------------------------------------------------------------------------------- /pysaliency/numba_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, unicode_literals, division, absolute_import 2 | 3 | import numba 4 | import numpy as np 5 | 6 | 7 | def fill_fixation_map(fixation_map, fixations, check_bounds=True): 8 | if check_bounds: 9 | if np.any(fixations < 0): 10 | raise ValueError("Negative fixation positions!") 11 | if np.any(fixations[:, 0] >= fixation_map.shape[0]): 12 | raise ValueError("Fixations y positions out of bound!") 13 | if np.any(fixations[:, 1] >= fixation_map.shape[1]): 14 | raise ValueError("Fixations x positions out of bound!") 15 | return _fill_fixation_map(fixation_map, fixations) 16 | 17 | 18 | @numba.jit(nopython=True) 19 | def _fill_fixation_map(fixation_map, fixations): 20 | """fixationmap: 2d array. fixations: Nx2 array of y, x positions""" 21 | for i in range(len(fixations)): 22 | fixation_y, fixation_x = fixations[i] 23 | fixation_map[int(fixation_y), int(fixation_x)] += 1 24 | 25 | 26 | def auc_for_one_positive(positive, negatives): 27 | """ Computes the AUC score of one single positive sample agains many negatives. 28 | 29 | The result is equal to general_roc([positive], negatives)[0], but computes much 30 | faster because one can save sorting the negatives. 31 | """ 32 | return _auc_for_one_positive(positive, np.asarray(negatives)) 33 | 34 | 35 | @numba.jit(nopython=True) 36 | def _auc_for_one_positive(positive, negatives): 37 | """ Computes the AUC score of one single positive sample agains many negatives. 38 | 39 | The result is equal to general_roc([positive], negatives)[0], but computes much 40 | faster because one can save sorting the negatives. 
41 | """ 42 | count = 0 43 | for negative in negatives: 44 | if negative < positive: 45 | count += 1 46 | elif negative == positive: 47 | count += 0.5 48 | 49 | return count / len(negatives) 50 | 51 | 52 | def general_roc_numba(positives, negatives, judd=0): 53 | sorted_positives = np.sort(positives)[::-1] 54 | sorted_negatives = np.sort(negatives)[::-1] 55 | 56 | if judd == 0: 57 | all_values = np.hstack([positives, negatives]) 58 | all_values = np.sort(all_values)[::-1] 59 | else: 60 | min_val = min(sorted_positives[len(positives)-1], sorted_negatives[len(negatives)-1]) 61 | max_val = max(sorted_positives[0], sorted_negatives[0]) + 1 62 | all_values = np.hstack((max_val, positives, min_val)) 63 | all_values = np.sort(all_values)[::-1] 64 | 65 | false_positive_rates = np.zeros(len(all_values) + 1) 66 | hit_rates = np.zeros(len(all_values) + 1) 67 | hit_rates, false_positive_rates = _general_roc_numba(all_values, sorted_positives, sorted_negatives, false_positive_rates, hit_rates) 68 | auc = np.trapz(hit_rates, false_positive_rates) 69 | 70 | return auc, hit_rates, false_positive_rates 71 | 72 | 73 | @numba.jit(nopython=True) 74 | def _general_roc_numba(all_values, sorted_positives, sorted_negatives, false_positive_rates, hit_rates): 75 | """calculate ROC score for given values of positive and negative 76 | distribution""" 77 | 78 | positive_count = len(sorted_positives) 79 | negative_count = len(sorted_negatives) 80 | true_positive_count = 0 81 | false_positive_count = 0 82 | for i in range(len(all_values)): 83 | theta = all_values[i] 84 | while true_positive_count < positive_count and sorted_positives[true_positive_count] >= theta: 85 | true_positive_count += 1 86 | while false_positive_count < negative_count and sorted_negatives[false_positive_count] >= theta: 87 | false_positive_count += 1 88 | false_positive_rates[i+1] = float(false_positive_count) / negative_count 89 | hit_rates[i+1] = float(true_positive_count) / positive_count 90 | 91 | return hit_rates, false_positive_rates 92 | 93 | 94 | def general_rocs_per_positive_numba(positives, negatives): 95 | sorted_positives = np.sort(positives) 96 | sorted_negatives = np.sort(negatives) 97 | sorted_inds = np.argsort(positives) 98 | 99 | results = np.empty(len(positives)) 100 | results = _general_rocs_per_positive_numba(sorted_positives, sorted_negatives, sorted_inds, results) 101 | 102 | return results 103 | 104 | 105 | @numba.jit(nopython=True) 106 | def _general_rocs_per_positive_numba(sorted_positives, sorted_negatives, sorted_inds, results): 107 | """calculate ROC scores for each positive against a list of negatives 108 | distribution. 
The mean over the result will equal the return value of `general_roc`.""" 109 | 110 | true_negatives_count = 0 111 | equal_count = 0 112 | last_theta = -np.inf 113 | negative_count = len(sorted_negatives) 114 | 115 | for i, theta in enumerate(sorted_positives): 116 | 117 | if theta == last_theta: 118 | results[sorted_inds[i]] = (1.0 * true_negatives_count + 0.5 * equal_count) / negative_count 119 | continue 120 | 121 | true_negatives_count = true_negatives_count + equal_count 122 | 123 | while true_negatives_count < negative_count and sorted_negatives[true_negatives_count] < theta: 124 | true_negatives_count += 1 125 | 126 | equal_count = 0 127 | while true_negatives_count + equal_count < negative_count and sorted_negatives[true_negatives_count + equal_count] <= theta: 128 | equal_count += 1 129 | 130 | results[sorted_inds[i]] = (1.0 * true_negatives_count + 0.5 * equal_count) / negative_count 131 | 132 | last_theta = theta 133 | 134 | return results -------------------------------------------------------------------------------- /pysaliency/optpy/README.md: -------------------------------------------------------------------------------- 1 | Copied from https://github.com/matthias-k/optpy since that name is already taken on pypi and I'm not sure whether it's worth packaging it anyway. 2 | -------------------------------------------------------------------------------- /pysaliency/optpy/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .optimization import ParameterManager, minimize, LinearConstraint 4 | from .jacobian import FunctionWithApproxJacobian, FunctionWithApproxJacobianCentral 5 | -------------------------------------------------------------------------------- /pysaliency/optpy/jacobian.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Matthias Kuemmerer, 2014 3 | 4 | """ 5 | from __future__ import print_function, division, unicode_literals, absolute_import 6 | 7 | import sys 8 | import numpy as np 9 | 10 | 11 | class FunctionWithApproxJacobian(object): 12 | def __init__(self, func, epsilon, verbose=True): 13 | self._func = func 14 | self.epsilon = epsilon 15 | self.value_cache = {} 16 | self.verbose = verbose 17 | 18 | def __call__(self, x, *args, **kwargs): 19 | key = tuple(x) 20 | if key not in self.value_cache: 21 | self.log('.') 22 | value = self._func(x, *args, **kwargs) 23 | if np.any(np.isnan(value)): 24 | print("Warning!
nan function value encountered at {0}".format(x)) 25 | self.value_cache[key] = value 26 | return self.value_cache[key] 27 | 28 | def func(self, x, *args, **kwargs): 29 | if self.verbose: 30 | print(x) 31 | return self(x, *args, **kwargs) 32 | 33 | def log(self, msg): 34 | if self.verbose: 35 | sys.stdout.write(msg) 36 | sys.stdout.flush() 37 | 38 | def jac(self, x, *args, **kwargs): 39 | self.log('G[') 40 | x0 = np.asfarray(x) 41 | #print x0 42 | dxs = np.zeros((len(x0), len(x0) + 1)) 43 | for i in range(len(x0)): 44 | dxs[i, i + 1] = self.epsilon 45 | results = [self(*(x0 + dxs[:, i], ) + args, **kwargs) for i in range(len(x0) + 1)] 46 | jac = np.zeros([len(x0), len(np.atleast_1d(results[0]))]) 47 | for i in range(len(x0)): 48 | jac[i] = (results[i + 1] - results[0]) / self.epsilon 49 | self.log(']') 50 | return jac.transpose() 51 | 52 | 53 | class FunctionWithApproxJacobianCentral(FunctionWithApproxJacobian): 54 | def jac(self, x, *args, **kwargs): 55 | self.log('G[') 56 | x0 = np.asfarray(x) 57 | #print x0 58 | dxs = np.zeros((len(x0), 2*len(x0))) 59 | for i in range(len(x0)): 60 | dxs[i, i] = -self.epsilon 61 | dxs[i, len(x0)+i] = self.epsilon 62 | results = [self(*(x0 + dxs[:, i], ) + args, **kwargs) for i in range(2*len(x0))] 63 | jac = np.zeros([len(x0), len(np.atleast_1d(results[0]))]) 64 | for i in range(len(x0)): 65 | jac[i] = (results[len(x0)+i] - results[i]) / (2*self.epsilon) 66 | self.log(']') 67 | return jac.transpose() 68 | -------------------------------------------------------------------------------- /pysaliency/quilt.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code to apply quilt patches to files 3 | 4 | This module enables pysaliency to use quilt patches 5 | to patch code from external saliency models. While 6 | on Linux, quilt itself could be used to apply the patches, 7 | on Windows and Mac quilt might not be available and 8 | can be nontrivial for users to install. 9 | 10 | It does not support all possible patch files but only 11 | the subset of functionality needed by pysaliency.
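A minimal usage sketch (the paths are hypothetical):

    series = QuiltSeries('/path/to/patches')
    series.apply('/path/to/model_source')

This reads the `series` file in the patches directory, parses every listed patch into PatchFile/Diff/Hunk objects, and applies the hunks to the files at the given location.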
12 | """ 13 | 14 | from __future__ import absolute_import, print_function, division, unicode_literals 15 | 16 | import os.path 17 | 18 | from .utils import full_split 19 | 20 | 21 | class Hunk(object): 22 | def __init__(self, lines): 23 | meta_data = lines.pop(0) 24 | a, src_data, target_data, b = meta_data.split() 25 | assert a == '@@' 26 | assert b == '@@' 27 | start, length = self.parse_position(src_data) 28 | assert start < 0 29 | self.source_start = -start 30 | self.source_length = length 31 | 32 | start, length = self.parse_position(target_data) 33 | assert start > 0 34 | self.target_start = start 35 | self.target_length = length 36 | 37 | self.lines = lines 38 | 39 | def parse_position(self, position): 40 | start, length = position.split(',') 41 | start = int(start) 42 | length = int(length) 43 | return start, length 44 | 45 | def apply(self, source, target): 46 | src_pos = self.source_start - 1 47 | assert len(target) == self.target_start - 1 48 | for line in self.lines: 49 | type, data = line[0], line[1:] 50 | if type == ' ': 51 | assert source[src_pos] == data 52 | target.append(data) 53 | src_pos += 1 54 | elif type == '-': 55 | assert source[src_pos] == data 56 | src_pos += 1 57 | elif type == '+': 58 | target.append(data) 59 | elif type == '\\': 60 | # Newline stuff, ignore 61 | pass 62 | else: 63 | raise ValueError(line) 64 | assert src_pos == self.source_start + self.source_length - 1 65 | assert len(target) == self.target_start + self.target_length - 1 66 | return src_pos 67 | 68 | 69 | class Diff(object): 70 | def __init__(self, lines): 71 | source = lines.pop(0) 72 | assert source.startswith('--- ') 73 | _, source = source.split('--- ', 1) 74 | source, _ = source.split('\t', 1) 75 | source = os.path.join(*full_split(source)[1:]) 76 | target = lines.pop(0) 77 | assert target.startswith('+++ ') 78 | _, target = target.split('+++ ', 1) 79 | target, _ = target.split('\t', 1) 80 | target = os.path.join(*full_split(target)[1:]) 81 | self.source_filename = source 82 | self.target_filename = target 83 | self.hunks = [] 84 | while lines: 85 | assert lines[0].startswith('@@ ') 86 | hunk_lines = [lines.pop(0)] 87 | while lines and not lines[0].startswith('@@ '): 88 | line = lines.pop(0) 89 | if line: 90 | hunk_lines.append(line) 91 | self.hunks.append(Hunk(hunk_lines)) 92 | 93 | def apply(self, location): 94 | hunks = list(self.hunks) 95 | source = open(os.path.join(location, self.source_filename)).read() 96 | source = source.split('\n') 97 | target = [] 98 | src_pos = 0 99 | while src_pos < len(source): 100 | if hunks: 101 | if hunks[0].source_start == src_pos+1: 102 | hunk = hunks.pop(0) 103 | src_pos = hunk.apply(source, target) 104 | continue 105 | target.append(source[src_pos]) 106 | src_pos += 1 107 | open(os.path.join(location, self.target_filename), 'w').write('\n'.join(target)) 108 | 109 | 110 | class PatchFile(object): 111 | def __init__(self, patch): 112 | self.diffs = [] 113 | lines = patch.split('\n') 114 | while lines: 115 | index1 = lines.pop(0) 116 | assert index1.startswith('Index: ') 117 | index2 = lines.pop(0) 118 | assert index2.startswith('==============') 119 | diff = [] 120 | diff.append(lines.pop(0)) 121 | while lines and not lines[0].startswith('Index: '): 122 | diff.append(lines.pop(0)) 123 | diff_obj = Diff(diff) 124 | self.diffs.append(diff_obj) 125 | 126 | def apply(self, location, verbose=True): 127 | for diff in self.diffs: 128 | if verbose: 129 | print("Patching {}".format(diff.source_filename)) 130 | diff.apply(location) 131 | 132 | 133 | class 
QuiltSeries(object): 134 | def __init__(self, patches_location): 135 | self.patches_location = patches_location 136 | series = open(os.path.join(self.patches_location, 'series')).read() 137 | self.patches = [] 138 | self.patch_names = [] 139 | for line in series.split('\n'): 140 | if not line: 141 | continue 142 | patch_content = open(os.path.join(self.patches_location, line)).read() 143 | self.patches.append(PatchFile(patch_content)) 144 | self.patch_names.append(line) 145 | 146 | def apply(self, location, verbose=True): 147 | for patch, name in zip(self.patches, self.patch_names): 148 | if verbose: 149 | print("Applying {}".format(name)) 150 | patch.apply(location, verbose=verbose) 151 | -------------------------------------------------------------------------------- /pysaliency/roc.py: -------------------------------------------------------------------------------- 1 | from .numba_utils import general_roc_numba as general_roc, general_rocs_per_positive_numba as general_rocs_per_positive 2 | -------------------------------------------------------------------------------- /pysaliency/roc_cython.pyx: -------------------------------------------------------------------------------- 1 | #%%cython 2 | # Circumvents a bug(?) in cython: 3 | # http://stackoverflow.com/a/13976504 4 | STUFF = "Hi" 5 | 6 | 7 | import numpy as np 8 | cimport numpy as np 9 | cimport cython 10 | 11 | 12 | #Do not check for index errors 13 | @cython.boundscheck(False) 14 | #Do not enable negative indices 15 | @cython.wraparound(False) 16 | #Use native c division 17 | @cython.cdivision(True) 18 | def real_ROC(image, fixation_data, int judd=0): 19 | fixations_orig = np.zeros_like(image) 20 | fixations_orig[fixation_data] = 1.0 21 | image_1d = image.flatten() 22 | cdef np.ndarray[double, ndim=1] fixations = fixations_orig.flatten() 23 | inds = image_1d.argsort() 24 | #image_1d = image_1d[inds] 25 | fixations = fixations[inds] 26 | cdef np.ndarray[double, ndim=1] image_sorted = image_1d[inds] 27 | cdef np.ndarray[double, ndim=1] fixation_values_sorted = image[fixation_data] 28 | fixation_values_sorted.sort() 29 | cdef int i 30 | cdef int N = image_1d.shape[0] 31 | cdef int fix_count = fixations.sum() 32 | cdef int false_count = N-fix_count 33 | cdef int correct_count = 0 34 | cdef int false_positive_count = 0 35 | cdef int length 36 | if judd: 37 | length = fix_count+2 38 | assert len(fixation_values_sorted) == fix_count 39 | else: 40 | length = N+1 41 | cdef np.ndarray[double, ndim=1] precs = np.zeros(length) 42 | cdef np.ndarray[double, ndim=1] false_positives = np.zeros(length) 43 | for i in range(N): 44 | #print fixations[N-i-1], 45 | #print image_1d[N-i-1] 46 | # Every pixel is a nonfixation 47 | false_positive_count += 1 48 | if fixations[N-i-1]: 49 | correct_count += 1 50 | if judd: 51 | if i == N - 1 or fixation_values_sorted[N-i-1] != fixation_values_sorted[N-i-2]: 52 | precs[correct_count] = float(correct_count)/fix_count 53 | false_positives[correct_count] = float(false_positive_count)/false_count 54 | else: 55 | precs[correct_count] = precs[correct_count - 1] 56 | false_positives[correct_count] = false_positives[correct_count - 1] 57 | if not judd: 58 | if i == N-1 or image_sorted[N-i-1] != image_sorted[N-i-2]: 59 | precs[i+1] = float(correct_count)/fix_count 60 | false_positives[i+1] = float(false_positive_count)/false_count 61 | else: 62 | precs[i+1] = precs[i] 63 | false_positives[i+1] = false_positives[i] 64 | #print false_positives[i+1] 65 | precs[length-1] = 1.0 66 | false_positives[length-1] = 1.0 67 | aoc 
= np.trapz(precs, false_positives) 68 | return aoc, precs, false_positives 69 | 70 | 71 | #Do not check for index errors 72 | @cython.boundscheck(False) 73 | #Do not enable negative indices 74 | @cython.wraparound(False) 75 | #Use native c division 76 | @cython.cdivision(True) 77 | def general_roc(np.ndarray[double, ndim=1] positives, np.ndarray[double, ndim=1] negatives, int judd=0): 78 | """calculate the ROC score for given values of the positive and negative 79 | distributions""" 80 | cdef np.ndarray[double, ndim=1] sorted_positives = np.sort(positives)[::-1] 81 | cdef np.ndarray[double, ndim=1] sorted_negatives = np.sort(negatives)[::-1] 82 | cdef np.ndarray[double, ndim=1] all_values 83 | if judd == 0: 84 | all_values = np.hstack([positives, negatives]) 85 | all_values = np.sort(all_values)[::-1] 86 | else: 87 | min_val = np.min([sorted_positives[len(positives)-1], sorted_negatives[len(negatives)-1]]) 88 | max_val = np.max([sorted_positives[0], sorted_negatives[0]])+1 89 | all_values = np.hstack((max_val, positives, min_val)) 90 | all_values = np.sort(all_values)[::-1] 91 | cdef np.ndarray[double, ndim=1] false_positive_rates = np.zeros(len(all_values)+1) 92 | cdef np.ndarray[double, ndim=1] hit_rates = np.zeros(len(all_values)+1) 93 | cdef int true_positive_count = 0 94 | cdef int false_positive_count = 0 95 | cdef int positive_count = len(positives) 96 | cdef int negative_count = len(negatives) 97 | cdef int i 98 | cdef double theta 99 | for i in range(len(all_values)): 100 | theta = all_values[i] 101 | while true_positive_count < positive_count and sorted_positives[true_positive_count] >= theta: 102 | true_positive_count += 1 103 | while false_positive_count < negative_count and sorted_negatives[false_positive_count] >= theta: 104 | false_positive_count += 1 105 | false_positive_rates[i+1] = float(false_positive_count) / negative_count 106 | hit_rates[i+1] = float(true_positive_count) / positive_count 107 | auc = np.trapz(hit_rates, false_positive_rates) 108 | return auc, hit_rates, false_positive_rates 109 | 110 | 111 | #Do not check for index errors 112 | @cython.boundscheck(False) 113 | #Do not enable negative indices 114 | @cython.wraparound(False) 115 | #Use native c division 116 | @cython.cdivision(True) 117 | def general_rocs_per_positive(np.ndarray[double, ndim=1] positives, np.ndarray[double, ndim=1] negatives): 118 | """calculate ROC scores for each positive against a list of negatives. 119 | 
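(This is the Cython counterpart of _general_rocs_per_positive_numba in numba_utils.py; pysaliency/roc.py re-exports the numba implementations.)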
The mean over the result will equal the return value of `general_roc`.""" 120 | cdef np.ndarray[double, ndim=1] sorted_positives = np.sort(positives) 121 | cdef np.ndarray[double, ndim=1] sorted_negatives = np.sort(negatives) 122 | cdef np.ndarray[long, ndim=1] sorted_inds = np.argsort(positives) 123 | 124 | cdef np.ndarray[double, ndim=1] results = np.empty(len(positives)) 125 | cdef int true_positive_count = 0 126 | cdef int false_positive_count = 0 127 | cdef int true_negative_count = 0 128 | cdef int positive_count = len(positives) 129 | cdef int negative_count = len(negatives) 130 | cdef int i 131 | cdef double last_theta = -np.inf 132 | cdef double theta 133 | 134 | cdef int true_negatives_count = 0 135 | cdef int equal_count = 0 136 | for i in range(len(sorted_positives)): 137 | theta = sorted_positives[i] 138 | #print('theta', theta) 139 | if theta == last_theta: 140 | #print('same') 141 | results[sorted_inds[i]] = (1.0*true_negatives_count + 0.5*equal_count) / negative_count 142 | continue 143 | 144 | true_negatives_count = true_negatives_count + equal_count 145 | 146 | while true_negatives_count < negative_count and sorted_negatives[true_negatives_count] < theta: 147 | true_negatives_count += 1 148 | #print('.') 149 | equal_count = 0 150 | while true_negatives_count + equal_count < negative_count and sorted_negatives[true_negatives_count+equal_count] <= theta: 151 | equal_count += 1 152 | #print('=') 153 | results[sorted_inds[i]] = (1.0*true_negatives_count + 0.5*equal_count) / negative_count 154 | 155 | last_theta = theta 156 | return results 157 | -------------------------------------------------------------------------------- /pysaliency/saliency_map_conversion.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | from __future__ import absolute_import, division, print_function # , unicode_literals 3 | 4 | import numpy as np 5 | 6 | from tqdm import tqdm 7 | 8 | 9 | def optimize_for_information_gain( 10 | model, fit_stimuli, fit_fixations, 11 | nonlinearity_target='density', 12 | nonlinearity_values='logdensity', 13 | num_nonlinearity=20, 14 | num_centerbias=12, 15 | blur_radius=0, 16 | optimize=None, 17 | average='image', 18 | saliency_min=None, 19 | saliency_max=None, 20 | batch_size=1, 21 | verbose=0, 22 | return_optimization_result=False, 23 | maxiter=1000, 24 | method='trust-constr', 25 | minimize_options=None, 26 | cache_directory=None, 27 | framework='torch'): 28 | """ Convert a saliency map model into a probabilistic model as described in Kümmerer et al., PNAS 2015.
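Example (a minimal sketch; `my_saliency_model`, `stimuli` and `fixations` are placeholders for a pysaliency saliency map model and a fitting dataset):

    probabilistic_model = optimize_for_information_gain(
        my_saliency_model, stimuli, fixations, framework='torch')

With framework='torch' the work is delegated to optimize_saliency_map_conversion from saliency_map_conversion_torch; the theano backend asserts nonlinearity_target='density', nonlinearity_values='logdensity', average='fixations' and batch_size=1.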
29 | """ 30 | 31 | if saliency_min is None or saliency_max is None: 32 | smax = -np.inf 33 | smin = np.inf 34 | for s in tqdm(fit_stimuli, disable=verbose < 2): 35 | smap = model.saliency_map(s) 36 | smax = np.max([smax, smap.max()]) 37 | smin = np.min([smin, smap.min()]) 38 | 39 | if saliency_min is None: 40 | saliency_min = smin 41 | if saliency_max is None: 42 | saliency_max = smax 43 | 44 | if framework == 'theano': 45 | assert nonlinearity_target == 'density' 46 | assert nonlinearity_values == 'logdensity' 47 | assert average == 'fixations' 48 | assert batch_size == 1 49 | assert minimize_options is None 50 | assert cache_directory is None 51 | 52 | from .saliency_map_conversion_theano import optimize_for_information_gain 53 | return optimize_for_information_gain( 54 | model, fit_stimuli, fit_fixations, 55 | num_nonlinearity=num_nonlinearity, 56 | num_centerbias=num_centerbias, 57 | blur_radius=blur_radius, 58 | optimize=optimize, 59 | saliency_min=saliency_min, 60 | saliency_max=saliency_max, 61 | verbose=verbose, 62 | return_optimization_result=return_optimization_result, 63 | maxiter=maxiter, 64 | method=method 65 | ) 66 | elif framework == 'torch': 67 | from .saliency_map_conversion_torch import optimize_saliency_map_conversion 68 | return optimize_saliency_map_conversion( 69 | model, fit_stimuli, fit_fixations, 70 | nonlinearity_target=nonlinearity_target, 71 | nonlinearity_values=nonlinearity_values, 72 | saliency_min=saliency_min, 73 | saliency_max=saliency_max, 74 | optimize=optimize, 75 | average=average, 76 | batch_size=batch_size, 77 | num_nonlinearity=num_nonlinearity, 78 | num_centerbias=num_centerbias, 79 | blur_radius=blur_radius, 80 | verbose=verbose, 81 | return_optimization_result=return_optimization_result, 82 | maxiter=maxiter, 83 | minimize_options=minimize_options, 84 | cache_directory=cache_directory, 85 | method=method 86 | ) 87 | -------------------------------------------------------------------------------- /pysaliency/sampling_models.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import, division, unicode_literals 2 | 3 | from abc import ABCMeta, abstractmethod 4 | 5 | from .utils import remove_trailing_nans 6 | 7 | 8 | class SamplingModelMixin(object, metaclass=ABCMeta): 9 | """A sampling model is supports sampling fixations and whole scanpaths.""" 10 | def sample_scanpath( 11 | self, stimulus, x_hist, y_hist, t_hist, samples, attributes=None, verbose=False, rst=None 12 | ): 13 | """return xs, ys, ts""" 14 | xs = list(remove_trailing_nans(x_hist)) 15 | ys = list(remove_trailing_nans(y_hist)) 16 | ts = list(remove_trailing_nans(t_hist)) 17 | if not len(xs) == len(ys) == len(ts): 18 | raise ValueError("Histories for x, y and t have to be the same length") 19 | 20 | for i in range(samples): 21 | x, y, t = self.sample_fixation(stimulus, xs, ys, ts, attributes=attributes, verbose=verbose, rst=rst) 22 | xs.append(x) 23 | ys.append(y) 24 | ts.append(t) 25 | 26 | return xs, ys, ts 27 | 28 | @abstractmethod 29 | def sample_fixation(self, stimulus, x_hist, y_hist, t_hist, attributes=None, verbose=False, rst=None): 30 | """return x, y, t""" 31 | raise NotImplementedError() 32 | 33 | 34 | class ScanpathSamplingModelMixin(SamplingModelMixin): 35 | """A sampling model which only has to implement sample_scanpath instead of sample_fixation""" 36 | @abstractmethod 37 | def sample_scanpath( 38 | self, stimulus, x_hist, y_hist, t_hist, samples, attributes=None, verbose=False, rst=None 
39 | ): 40 | raise NotImplementedError() 41 | 42 | def sample_fixation(self, stimulus, x_hist, y_hist, t_hist, attributes=None, verbose=False, rst=None): 43 | samples = 1 44 | xs, ys, ts = self.sample_scanpath(stimulus, x_hist, y_hist, t_hist, samples, attributes=attributes, 45 | verbose=verbose, rst=rst) 46 | return xs[-1], ys[-1], ts[-1] 47 | -------------------------------------------------------------------------------- /pysaliency/tf_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division, absolute_import, unicode_literals 2 | 3 | import tensorflow as tf 4 | slim = tf.contrib.slim 5 | 6 | 7 | def normalize_axis(input_tensor, axis): 8 | """ convert negative indices into positive indices since tensorflow can't handle them """ 9 | if axis < 0: 10 | ndims = len(input_tensor.get_shape()) 11 | axis = ndims + axis 12 | return axis 13 | 14 | 15 | def replication_padding(input_tensor, axis=0, size=1): 16 | """ add replication padding to a tensor along a given axis """ 17 | with tf.name_scope('replication_padding'): 18 | if not isinstance(size, (tuple, list)): 19 | size = (size, size) 20 | ndims = len(input_tensor.get_shape()) 21 | axis = normalize_axis(input_tensor, axis) 22 | start_slice_obj = [slice(None)] * axis + [slice(0, 1)] 23 | start_slice = input_tensor[start_slice_obj] 24 | repeats = [1] * axis + [size[0]] + [1] * (ndims-axis-1) 25 | start_part = tf.tile(start_slice, repeats) 26 | end_slice_obj = [slice(None)] * axis + [slice(-1, None)] 27 | end_slice = input_tensor[end_slice_obj] 28 | repeats = [1] * axis + [size[1]] + [1] * (ndims-axis-1) 29 | end_part = tf.tile(end_slice, repeats) 30 | return tf.concat((start_part, input_tensor, end_part), axis=axis) 31 | 32 | 33 | def get_gaussian_kernel(sigma, windowradius=5): 34 | with tf.name_scope('gaussian_kernel'): 35 | kernel = tf.cast(tf.range(0, 2*windowradius+1), 'float') - windowradius 36 | kernel = tf.exp(-(kernel**2)/(2*sigma**2)) 37 | kernel /= tf.reduce_sum(kernel) 38 | return kernel 39 | 40 | 41 | def blowup_1d_kernel(kernel, axis=-1): 42 | #with tf.name_scope("blowup_1d_kernel") 43 | assert isinstance(axis, int) 44 | 45 | shape = [1 for i in range(4)] 46 | shape[axis] = -1 47 | return tf.reshape(kernel, shape) 48 | 49 | 50 | @slim.add_arg_scope 51 | def gaussian_convolution_along_axis(inputs, axis, sigma, windowradius=5, mode='NEAREST', scope=None, 52 | outputs_collections=None): 53 | with tf.name_scope(scope, 'gauss_1d', [inputs, sigma, windowradius]): 54 | if mode == 'NEAREST': 55 | inputs = replication_padding(inputs, axis=axis+1, size=windowradius) 56 | elif mode == 'ZERO': 57 | paddings = [[0, 0], [0, 0], [0, 0], [0, 0]] 58 | paddings[axis+1] = [windowradius, windowradius] 59 | inputs = tf.pad(inputs, paddings) 60 | elif mode == 'VALID': 61 | pass 62 | else: 63 | raise ValueError(mode) 64 | 65 | kernel_1d = get_gaussian_kernel(sigma, windowradius=windowradius) 66 | kernel = blowup_1d_kernel(kernel_1d, axis) 67 | #print(windowradius) 68 | 69 | output = tf.nn.conv2d(inputs, kernel, 70 | strides=[1, 1, 1, 1], padding="VALID", name='gaussian_convolution') 71 | return output 72 | 73 | #return slim.utils.collect_named_outputs(outputs_collections, sc, output) 74 | 75 | 76 | @slim.add_arg_scope 77 | def gauss_blur(inputs, sigma, windowradius=5, mode='NEAREST', scope=None, 78 | outputs_collections=None): 79 | with tf.name_scope(scope, 'gauss_blur', [inputs, sigma, windowradius]) as sc: 80 | 81 | outputs = inputs 82 | 83 | for axis in [0, 1]: 84 | 
85 | outputs = gaussian_convolution_along_axis(outputs, 86 | axis=axis, 87 | sigma=sigma, 88 | windowradius=windowradius, 89 | mode=mode) 90 | return outputs 91 | 92 | # return slim.utils.collect_named_outputs(outputs_collections, sc, outputs) 93 | 94 | 95 | def tf_logsumexp(data, axis=0): 96 | """computes logsumexp along axis in its own graph and session""" 97 | with tf.Graph().as_default() as g: 98 | input_tensor = tf.placeholder(tf.float32, name='input_tensor') 99 | output_tensor = tf.reduce_logsumexp(input_tensor, axis=axis) 100 | 101 | with tf.Session(graph=g) as sess: 102 | return sess.run(output_tensor, {input_tensor: data}) 103 | -------------------------------------------------------------------------------- /pysaliency/torch_datasets.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | from boltons.iterutils import chunked 6 | from torch.utils.data.dataloader import default_collate 7 | from tqdm import tqdm 8 | 9 | from .models import Model 10 | from .saliency_map_models import SaliencyMapModel 11 | 12 | 13 | def ensure_color_image(image): 14 | if len(image.shape) == 2: 15 | return np.dstack([image, image, image]) 16 | return image 17 | 18 | 19 | def x_y_to_sparse_indices(xs, ys): 20 | # Converts list of x and y coordinates into indices and values for sparse mask 21 | x_inds = [] 22 | y_inds = [] 23 | values = [] 24 | pair_inds = {} 25 | 26 | for x, y in zip(xs, ys): 27 | key = (x, y) 28 | if key not in pair_inds: 29 | x_inds.append(x) 30 | y_inds.append(y) 31 | pair_inds[key] = len(x_inds) - 1 32 | values.append(1) 33 | else: 34 | values[pair_inds[key]] += 1 35 | 36 | return np.array([y_inds, x_inds]), values 37 | 38 | 39 | class ImageDataset(torch.utils.data.Dataset): 40 | def __init__(self, stimuli, fixations, models=None, transform=None, cached=True, average='fixation'): 41 | self.stimuli = stimuli 42 | self.fixations = fixations 43 | 44 | if models is None: 45 | models = {} 46 | 47 | self.models = models 48 | self.transform = transform 49 | self.average = average 50 | 51 | self.cached = cached 52 | if cached: 53 | self._cache = {} 54 | print("Populating fixations cache") 55 | self._xs_cache = {} 56 | self._ys_cache = {} 57 | 58 | for x, y, n in zip(self.fixations.x_int, self.fixations.y_int, tqdm(self.fixations.n)): 59 | self._xs_cache.setdefault(n, []).append(x) 60 | self._ys_cache.setdefault(n, []).append(y) 61 | 62 | for key in list(self._xs_cache): 63 | self._xs_cache[key] = np.array(self._xs_cache[key], dtype=int) 64 | for key in list(self._ys_cache): 65 | self._ys_cache[key] = np.array(self._ys_cache[key], dtype=int) 66 | 67 | def get_shapes(self): 68 | return list(self.stimuli.sizes) 69 | 70 | def __getitem__(self, key): 71 | if not self.cached or key not in self._cache: 72 | image = np.array(self.stimuli.stimuli[key]) 73 | 74 | predictions = {} 75 | for model_name, model in self.models.items(): 76 | if isinstance(model, Model): 77 | prediction = np.asarray(model.log_density(image)) 78 | elif isinstance(model, SaliencyMapModel): 79 | prediction = np.asarray(model.saliency_map(image)) 80 | predictions[model_name] = prediction 81 | 82 | image = ensure_color_image(image).astype(np.float32) 83 | image = image.transpose(2, 0, 1) 84 | 85 | if self.cached: 86 | xs = self._xs_cache.pop(key) 87 | ys = self._ys_cache.pop(key) 88 | else: 89 | inds = self.fixations.n == key 90 | xs = np.array(self.fixations.x_int[inds], dtype=int) 91 | ys = np.array(self.fixations.y_int[inds], 
dtype=int) 92 | data = { 93 | "image": image, 94 | "x": xs, 95 | "y": ys, 96 | } 97 | 98 | for prediction_name, prediction in predictions.items(): 99 | data[prediction_name] = prediction 100 | 101 | if self.average == 'image': 102 | data['weight'] = 1.0 103 | else: 104 | data['weight'] = float(len(xs)) 105 | 106 | if self.cached: 107 | self._cache[key] = data 108 | else: 109 | data = self._cache[key] 110 | 111 | if self.transform is not None: 112 | return self.transform(dict(data)) 113 | 114 | return data 115 | 116 | def __len__(self): 117 | return len(self.stimuli) 118 | 119 | 120 | class FixationMaskTransform(object): 121 | def __call__(self, item): 122 | shape = torch.Size([item['image'].shape[1], item['image'].shape[2]]) 123 | x = item.pop('x') 124 | y = item.pop('y') 125 | 126 | # inds, values = x_y_to_sparse_indices(x, y) 127 | inds = np.array([y, x]) 128 | values = np.ones(len(y), dtype=int) 129 | 130 | # mask = torch.sparse.IntTensor(torch.tensor(inds), torch.tensor(values), shape) 131 | mask = torch.sparse_coo_tensor(torch.tensor(inds), torch.tensor(values), shape, dtype=torch.int) 132 | mask = mask.coalesce() 133 | 134 | item['fixation_mask'] = mask 135 | 136 | return item 137 | 138 | 139 | class ImageDatasetSampler(torch.utils.data.Sampler): 140 | def __init__(self, data_source, batch_size=1, ratio_used=1.0, shuffle=True): 141 | self.ratio_used = ratio_used 142 | self.shuffle = shuffle 143 | 144 | shapes = data_source.get_shapes() 145 | unique_shapes = sorted(set(shapes)) 146 | 147 | shape_indices = [[] for shape in unique_shapes] 148 | 149 | for k, shape in enumerate(shapes): 150 | shape_indices[unique_shapes.index(shape)].append(k) 151 | 152 | if self.shuffle: 153 | for indices in shape_indices: 154 | random.shuffle(indices) 155 | 156 | self.batches = sum([chunked(indices, size=batch_size) for indices in shape_indices], []) 157 | 158 | def __iter__(self): 159 | if self.shuffle: 160 | indices = torch.randperm(len(self.batches)) 161 | else: 162 | indices = range(len(self.batches)) 163 | 164 | if self.ratio_used < 1.0: 165 | indices = indices[:int(self.ratio_used * len(indices))] 166 | 167 | return iter(self.batches[i] for i in indices) 168 | 169 | def __len__(self): 170 | return int(self.ratio_used * len(self.batches)) 171 | 172 | 173 | # we need to extend the default collate fn to handle sparse coo tensors 174 | def collate_fn(batch): 175 | result = {} 176 | for key in batch[0]: 177 | if isinstance(batch[0][key], torch.sparse.Tensor): 178 | result[key] = torch.stack([item[key] for item in batch], 0) 179 | else: 180 | result[key] = default_collate([item[key] for item in batch]) 181 | return result -------------------------------------------------------------------------------- /pysaliency/utils/variable_length_array.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, List 2 | 3 | import numpy as np 4 | 5 | from . import build_padded_2d_array 6 | 7 | 8 | class VariableLengthArray: 9 | """ 10 | Represents a variable length array. 11 | 12 | The following indexing operations are supported: 13 | - Accessing rows: array[i] 14 | - Accessing elements: array[i, j] where j can also be negative to get elements from the end of the row 15 | - Slicing: array[i:j, k] where k can also be negative to get elements from the end of each row 16 | 17 | 18 | Args: 19 | data (Union[np.ndarray, list[list]]): The data for the array. Can be either a numpy array or a list of lists.
20 | lengths (np.ndarray): The lengths of each row in the data array. 21 | 22 | Attributes: 23 | _data (np.ndarray): The internal data array with padded rows. 24 | lengths (np.ndarray): The lengths of each row in the data array. 25 | 26 | Methods: 27 | __len__(): Returns the number of rows in the array. 28 | __getitem__(index): Returns the value(s) at the specified index(es) in the array. 29 | """ 30 | 31 | def __init__(self, data: Union[np.ndarray, List[list]], lengths: Optional[np.ndarray] = None): 32 | """ 33 | Initialize the VariableLengthArray object. 34 | 35 | Args: 36 | data (Union[np.ndarray, list[list]]): The input data, which can be either a numpy array or a list of lists. 37 | lengths (np.ndarray): An array containing the lengths of each row in the data. 38 | 39 | Raises: 40 | ValueError: If the input data shape doesn't match the provided lengths. 41 | 42 | """ 43 | 44 | if lengths is not None: 45 | if len(data) != len(lengths): 46 | raise ValueError(f"The number of rows in the data array has to match the number of elements in lengths ({len(data)} != {len(lengths)})") 47 | 48 | if not isinstance(data, np.ndarray): 49 | for row, length in zip(data, lengths): 50 | if len(row) != length: 51 | raise ValueError(f"The length of row {row} does not match the specified length {length}") 52 | else: 53 | if not data.ndim >= 2: 54 | raise ValueError("If data is a numpy array, it has to be at least 2-dimensional") 55 | if len(lengths) and np.max(lengths) > data.shape[1]: 56 | raise ValueError("The specified lengths are larger than the number of columns in the data array") 57 | 58 | lengths = np.array(lengths, dtype=int) 59 | 60 | else: 61 | if isinstance(data, np.ndarray): 62 | raise ValueError("If data is a numpy array, lengths must be provided") 63 | lengths = np.array([len(row) for row in data]) 64 | 65 | if isinstance(data, np.ndarray): 66 | self._data = data 67 | else: 68 | self._data = build_padded_2d_array(data, max_length=np.max(lengths) if len(lengths) else 0) 69 | 70 | # max_len = np.max(lengths) 71 | # self._data = np.full((len(data), max_len), np.nan) 72 | # for i, row in enumerate(data): 73 | # if len(row) < lengths[i]: 74 | # raise ValueError(f"Row {i} has fewer elements than specified in lengths ({len(row)} < {lengths[i]}") 75 | # self._data[i, :lengths[i]] = row[:lengths[i]] 76 | self.lengths = lengths 77 | 78 | def __len__(self): 79 | return len(self._data) 80 | 81 | def __getitem__(self, index): 82 | if isinstance(index, tuple): 83 | row_idx, col_idx = index 84 | if isinstance(row_idx, slice): 85 | if isinstance(col_idx, int): 86 | return np.array([self._data[i, :self.lengths[i]][col_idx] for i in range(*row_idx.indices(len(self._data)))]) 87 | elif isinstance(col_idx, slice): 88 | # does this work? 
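# note: with a slice row index, self.lengths[row_idx] is itself an array, so the :self.lengths[row_idx] bound in the branch below may not behave as a scalar slice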
89 | return self._data[row_idx, :self.lengths[row_idx]][col_idx] 90 | else: 91 | return self._data[row_idx, :self.lengths[row_idx]][col_idx] 92 | elif isinstance(index, (int, np.integer)): 93 | return self._data[index, :self.lengths[index]] 94 | else: 95 | return VariableLengthArray(self._data[index], self.lengths[index]) 96 | # new_lengths = self.lengths[index] 97 | # max_length = np.max(new_lengths) 98 | # new_data = self._data[index, :max_length] 99 | # return VariableLengthArray(new_data, new_lengths) 100 | 101 | def copy(self) -> 'VariableLengthArray': 102 | return VariableLengthArray(self._data.copy(), self.lengths.copy()) 103 | 104 | def __repr__(self): 105 | representation = "VariableLengthArray(\n" 106 | if len(self) < 10: 107 | for i in range(len(self)): 108 | representation += f" {self[i]}\n" 109 | else: 110 | for i in range(5): 111 | representation += f" {self[i]}\n" 112 | representation += " ...\n" 113 | for i in range(len(self)-5, len(self)): 114 | representation += f" {self[i]}\n" 115 | representation += ")" 116 | return representation 117 | 118 | 119 | def concatenate_variable_length_arrays(arrays: List[VariableLengthArray]) -> VariableLengthArray: 120 | """ 121 | Concatenate a list of VariableLengthArray objects along the first axis. 122 | 123 | Args: 124 | arrays (List[VariableLengthArray]): List of VariableLengthArray objects to concatenate. 125 | 126 | Returns: 127 | VariableLengthArray: The concatenated VariableLengthArray object. 128 | """ 129 | lengths = np.concatenate([array.lengths for array in arrays]) 130 | 131 | datas = [array._data for array in arrays] 132 | max_cols = max(a.shape[1] for a in datas) 133 | padded_datas = [] 134 | for a in datas: 135 | if a.shape[1] < max_cols: 136 | padding = np.empty((a.shape[0], max_cols-a.shape[1]), dtype=a.dtype) 137 | padding[:] = np.nan 138 | padded_datas.append(np.hstack((a, padding))) 139 | else: 140 | padded_datas.append(a) 141 | data = np.vstack(padded_datas) 142 | 143 | return VariableLengthArray(data, lengths) 144 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | slow: marks tests as being slow (skipped by default, select with '--runslow') 4 | nonfree: marks tests as requiring nonpublic data (skipped by default, select with '--run-nonfree') 5 | theano: marks tests as using theano (deselect with '-m "not theano"') 6 | matlab: marks tests as using matlab (deselect with '-m "not matlab"') 7 | download: marks tests as using downloads (deselect with '-m "not download"') 8 | skip_octave: marks tests to be skipped when using octave instead of matlab 9 | serial 10 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Cython 2 | dill 3 | h5py 4 | imageio 5 | natsort 6 | numpy 7 | scipy 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from os import path 3 | 4 | from setuptools import setup, find_packages 5 | from setuptools.extension import Extension 6 | from Cython.Build import cythonize 7 | 8 | import numpy as np 9 | import io 10 | 11 | PACKAGE_NAME = 'pysaliency' 12 | VERSION = '0.2.22' 13 | DESCRIPTION = 'A Python Framework for Saliency Modeling and Evaluation' 14 | AUTHOR = 
'Matthias Kümmerer' 15 | EMAIL = 'matthias.kuemmerer@bethgelab.org' 16 | URL = "https://github.com/matthiask/pysaliency" 17 | 18 | try: 19 | this_directory = path.abspath(path.dirname(__file__)) 20 | with io.open(path.join(this_directory, 'README.md'), encoding='utf-8') as f: 21 | long_description = f.read() 22 | except IOError: 23 | long_description = '' 24 | 25 | extensions = [ 26 | Extension("pysaliency.roc_cython", ['pysaliency/*.pyx'], 27 | include_dirs = [np.get_include()], 28 | extra_compile_args = ['-O3'], 29 | #extra_compile_args = ['-fopenmp', '-O3'], 30 | #extra_link_args=["-fopenmp"] 31 | ), 32 | ] 33 | 34 | 35 | setup( 36 | name = PACKAGE_NAME, 37 | version = VERSION, 38 | description = 'python library to develop, evaluate and benchmark saliency models', 39 | long_description = long_description, 40 | long_description_content_type='text/markdown', 41 | classifiers=[ 42 | "Development Status :: 3 - Alpha", 43 | "Intended Audience :: Developers", 44 | "Intended Audience :: Science/Research", 45 | "License :: OSI Approved :: MIT License", 46 | #"Programming Language :: Python :: 2.7", 47 | "Programming Language :: Python :: 3", 48 | #"Programming Language :: Python :: 3.5", 49 | "Programming Language :: Python :: 3.6", 50 | "Topic :: Scientific/Engineering", 51 | ], 52 | packages = find_packages(), 53 | author = AUTHOR, 54 | author_email = EMAIL, 55 | url = URL, 56 | license = 'MIT', 57 | install_requires=[ 58 | 'boltons', 59 | 'deprecation', 60 | 'imageio', 61 | 'natsort', 62 | 'numba', 63 | 'numpy', 64 | 'piexif', 65 | 'requests', 66 | 'schema', 67 | 'scipy', 68 | 'setuptools', 69 | 'tqdm', 70 | ], 71 | include_package_data = True, 72 | package_data={'pysaliency': ['external_models/scripts/*.m', 73 | 'external_models/scripts/*/*.m', 74 | 'external_models/scripts/*/*/*', 75 | 'external_models/scripts/BMS/patches/*', 76 | 'external_models/scripts/GBVS/patches/*', 77 | 'external_models/scripts/Judd/patches/*', 78 | 'external_datasets/scripts/*.m' 79 | ]}, 80 | ext_modules = cythonize(extensions), 81 | ) 82 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from pathlib import Path 4 | from unittest.mock import MagicMock 5 | 6 | import pytest 7 | import requests 8 | from tqdm import tqdm 9 | 10 | # make sure that package can be found when running `pytest` instead of `python -m pytest` 11 | sys.path.insert(0, os.getcwd()) 12 | 13 | def pytest_addoption(parser): 14 | parser.addoption("--runslow", action="store_true", 15 | default=False, help="run slow tests") 16 | parser.addoption("--run-nonfree", action="store_true", 17 | default=False, help="run tests requiring nonpublic data") 18 | parser.addoption("--nomatlab", action="store_true", default=False, help="don't run matlab tests") 19 | parser.addoption("--nooctave", action="store_true", default=False, help="don't run octave tests") 20 | parser.addoption("--notheano", action="store_true", default=False, help="don't run slow theano tests") 21 | parser.addoption("--nodownload", action="store_true", default=False, help="don't download external data") 22 | 23 | 24 | def pytest_collection_modifyitems(config, items): 25 | run_slow = config.getoption("--runslow") 26 | run_nonfree = config.getoption('--run-nonfree') 27 | no_matlab = config.getoption("--nomatlab") 28 | no_theano = config.getoption("--notheano") 29 | no_download = config.getoption("--nodownload") 30 | skip_slow = 
pytest.mark.skip(reason="need --runslow option to run") 31 | skip_nonfree = pytest.mark.skip(reason="need --run-nonfree option to run") 32 | skip_matlab = pytest.mark.skip(reason="skipped because of --nomatlab") 33 | skip_theano = pytest.mark.skip(reason="skipped because of --notheano") 34 | skip_download = pytest.mark.skip(reason="skipped because of --nodownload") 35 | for item in items: 36 | if "slow" in item.keywords and not run_slow: 37 | item.add_marker(skip_slow) 38 | if 'nonfree' in item.keywords and not run_nonfree: 39 | item.add_marker(skip_nonfree) 40 | if "matlab" in item.keywords and no_matlab: 41 | item.add_marker(skip_matlab) 42 | if "theano" in item.keywords and no_theano: 43 | item.add_marker(skip_theano) 44 | if "download" in item.keywords and no_download: 45 | item.add_marker(skip_download) 46 | 47 | 48 | @pytest.fixture(params=["matlab", "octave"]) 49 | def matlab(request, pytestconfig): 50 | import pysaliency.utils 51 | if request.param == "matlab": 52 | pysaliency.utils.MatlabOptions.matlab_names = ['matlab', 'matlab.exe'] 53 | pysaliency.utils.MatlabOptions.octave_names = [] 54 | elif request.param == 'octave': 55 | if pytestconfig.getoption("--nooctave"): 56 | pytest.skip("skipped octave due to command line option") 57 | elif any([marker.name == 'skip_octave' for marker in request.node.own_markers]): 58 | pytest.skip("skipped octave due to test marker") 59 | pysaliency.utils.MatlabOptions.matlab_names = [] 60 | pysaliency.utils.MatlabOptions.octave_names = ['octave', 'octave.exe'] 61 | 62 | return request.param 63 | 64 | #@pytest.fixture(params=["no_location", "with_location"]) 65 | #def location(tmpdir, request): 66 | # if request.param == 'no_location': 67 | # return None 68 | # elif request.param == 'with_location': 69 | # return tmpdir 70 | # else: 71 | # raise ValueError(request.param) 72 | 73 | 74 | # we don't test in memory external datasets anymore 75 | # we'll probably get rid of them anyway 76 | # TODO: remove this fixture, replace with tmpdir 77 | @pytest.fixture() 78 | def location(tmpdir): 79 | return tmpdir 80 | 81 | 82 | @pytest.fixture(autouse=True) 83 | def cache_requests(monkeypatch): 84 | """This fixture caches requests to avoid downloading the same file multiple times. 85 | 86 | TODO: There should be an option to disable this fixture, e.g. when we want to test downloading. 
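For example (a hypothetical URL): https://example.com/data.zip?v=1 would be cached as download_cache/example_com_data_zip_v_1, i.e. the scheme is stripped and '/', '?', '=', '&', ':' and '.' are replaced by underscores.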
87 | """ 88 | original_get = requests.get 89 | 90 | def mock_get(url, *args, **kwargs): 91 | cache_dir = Path("download_cache") 92 | cache_dir.mkdir(exist_ok=True) 93 | 94 | cache_filename = ( 95 | url.replace("http://", "") 96 | .replace("https://", "") 97 | .replace("/", "_") 98 | .replace("?", "_") 99 | .replace("=", "_") 100 | .replace("&", "_") 101 | .replace(":", "_") 102 | .replace(".", "_") 103 | ) 104 | cache_file = cache_dir / cache_filename 105 | 106 | print("caching", url, "to", cache_file) 107 | 108 | if not cache_file.exists(): 109 | response = original_get(url, *args, **kwargs) 110 | total_size = int(response.headers.get('content-length', 0)) 111 | with open(cache_file, 'wb') as f: 112 | with tqdm(total=total_size, unit='B', unit_scale=True, desc='Downloading file') as progress_bar: 113 | for chunk in response.iter_content(32*1024): 114 | f.write(chunk) 115 | progress_bar.update(len(chunk)) 116 | 117 | with open(cache_file, 'rb') as f: 118 | content = f.read() 119 | mock_response = MagicMock() 120 | mock_response.iter_content = lambda chunk_size: [content[i:i+chunk_size] for i in range(0, len(content), chunk_size)] 121 | mock_response.headers = {'content-length': str(len(content))} 122 | mock_response.status_code = 200 123 | return mock_response 124 | 125 | monkeypatch.setattr(requests, "get", mock_get) 126 | -------------------------------------------------------------------------------- /tests/datasets/utils.py: -------------------------------------------------------------------------------- 1 | from pysaliency.datasets import ScanpathFixations 2 | from pysaliency.datasets.scanpaths import Scanpaths 3 | 4 | 5 | import numpy as np 6 | 7 | from pysaliency.utils.variable_length_array import VariableLengthArray 8 | 9 | 10 | def assert_variable_length_array_equal(array1, array2): 11 | assert isinstance(array1, VariableLengthArray) 12 | assert isinstance(array2, VariableLengthArray) 13 | assert len(array1) == len(array2) 14 | 15 | for i in range(len(array1)): 16 | np.testing.assert_array_equal(array1[i], array2[i], err_msg=f'arrays not equal at index {i}') 17 | 18 | 19 | def assert_scanpaths_equal(scanpaths1: Scanpaths, scanpaths2: Scanpaths, scanpaths2_inds=None): 20 | 21 | if scanpaths2_inds is None: 22 | scanpaths2_inds = slice(None) 23 | 24 | assert isinstance(scanpaths1, Scanpaths) 25 | assert isinstance(scanpaths2, Scanpaths) 26 | 27 | assert_variable_length_array_equal(scanpaths1.xs, scanpaths2.xs[scanpaths2_inds]) 28 | assert_variable_length_array_equal(scanpaths1.ys, scanpaths2.ys[scanpaths2_inds]) 29 | 30 | assert scanpaths1.scanpath_attributes.keys() == scanpaths2.scanpath_attributes.keys() 31 | for attribute_name in scanpaths1.scanpath_attributes.keys(): 32 | np.testing.assert_array_equal(scanpaths1.scanpath_attributes[attribute_name], scanpaths2.scanpath_attributes[attribute_name][scanpaths2_inds]) 33 | 34 | assert scanpaths1.fixation_attributes.keys() == scanpaths2.fixation_attributes.keys() 35 | for attribute_name in scanpaths1.fixation_attributes.keys(): 36 | assert_variable_length_array_equal(scanpaths1.fixation_attributes[attribute_name], scanpaths2.fixation_attributes[attribute_name][scanpaths2_inds]) 37 | 38 | assert scanpaths1.attribute_mapping == scanpaths2.attribute_mapping 39 | 40 | 41 | def compare_fixations_subset(f1, f2, f2_inds): 42 | np.testing.assert_allclose(f1.x, f2.x[f2_inds]) 43 | np.testing.assert_allclose(f1.y, f2.y[f2_inds]) 44 | np.testing.assert_allclose(f1.t, f2.t[f2_inds]) 45 | np.testing.assert_allclose(f1.n, f2.n[f2_inds]) 46 | 
np.testing.assert_allclose(f1.subject, f2.subject[f2_inds]) 47 | 48 | assert f1.__attributes__ == f2.__attributes__ 49 | for attribute in f1.__attributes__: 50 | if attribute == 'scanpath_index': 51 | continue 52 | np.testing.assert_array_equal(getattr(f1, attribute), getattr(f2, attribute)[f2_inds]) 53 | 54 | 55 | def assert_fixations_equal(f1, f2, crop_length=False): 56 | if crop_length: 57 | maximum_length = np.max(f2.scanpath_history_length) 58 | else: 59 | maximum_length = max(np.max(f1.scanpath_history_length), np.max(f2.scanpath_history_length)) 60 | np.testing.assert_array_equal(f1.x, f2.x) 61 | np.testing.assert_array_equal(f1.y, f2.y) 62 | np.testing.assert_array_equal(f1.t, f2.t) 63 | np.testing.assert_array_equal(f1.n, f2.n) 64 | assert_variable_length_array_equal(f1.x_hist, f2.x_hist) 65 | assert_variable_length_array_equal(f1.y_hist, f2.y_hist) 66 | assert_variable_length_array_equal(f1.t_hist, f2.t_hist) 67 | 68 | f1_attributes = set(f1.__attributes__) 69 | f2_attributes = set(f2.__attributes__) 70 | 71 | assert set(f1_attributes) == set(f2_attributes) 72 | for attribute in f1.__attributes__: 73 | if attribute == 'scanpath_index': 74 | continue 75 | attribute1 = getattr(f1, attribute) 76 | attribute2 = getattr(f2, attribute) 77 | 78 | if isinstance(attribute1, VariableLengthArray): 79 | assert_variable_length_array_equal(attribute1, attribute2) 80 | continue 81 | elif attribute.endswith('_hist'): 82 | attribute1 = attribute1[:, :maximum_length] 83 | attribute2 = attribute2[:, :maximum_length] 84 | 85 | np.testing.assert_array_equal(attribute1, attribute2, err_msg=f'attributes not equal: {attribute}') 86 | 87 | 88 | def assert_fixation_trains_equal(fixation_trains1, fixation_trains2): 89 | assert_variable_length_array_equal(fixation_trains1.train_xs, fixation_trains2.train_xs) 90 | assert_variable_length_array_equal(fixation_trains1.train_ys, fixation_trains2.train_ys) 91 | assert_variable_length_array_equal(fixation_trains1.train_ts, fixation_trains2.train_ts) 92 | 93 | np.testing.assert_array_equal(fixation_trains1.train_ns, fixation_trains2.train_ns) 94 | np.testing.assert_array_equal(fixation_trains1.train_subjects, fixation_trains2.train_subjects) 95 | np.testing.assert_array_equal(fixation_trains1.train_lengths, fixation_trains2.train_lengths) 96 | 97 | assert fixation_trains1.scanpath_attribute_mapping == fixation_trains2.scanpath_attribute_mapping 98 | 99 | assert fixation_trains1.scanpath_attributes.keys() == fixation_trains2.scanpath_attributes.keys() 100 | for attribute_name in fixation_trains1.scanpath_attributes.keys(): 101 | np.testing.assert_array_equal(fixation_trains1.scanpath_attributes[attribute_name], fixation_trains2.scanpath_attributes[attribute_name]) 102 | 103 | assert fixation_trains1.scanpath_fixation_attributes.keys() == fixation_trains2.scanpath_fixation_attributes.keys() 104 | for attribute_name in fixation_trains1.scanpath_fixation_attributes.keys(): 105 | assert_variable_length_array_equal(fixation_trains1.scanpath_fixation_attributes[attribute_name], fixation_trains2.scanpath_fixation_attributes[attribute_name]) 106 | 107 | assert_fixations_equal(fixation_trains1, fixation_trains2) 108 | 109 | 110 | def assert_scanpath_fixations_equal(scanpath_fixations1: ScanpathFixations, scanpath_fixations2: ScanpathFixations): 111 | assert isinstance(scanpath_fixations1, ScanpathFixations) 112 | assert isinstance(scanpath_fixations2, ScanpathFixations) 113 | assert_scanpaths_equal(scanpath_fixations1.scanpaths, scanpath_fixations2.scanpaths) 114 | 
assert_fixations_equal(scanpath_fixations1, scanpath_fixations2) -------------------------------------------------------------------------------- /tests/external_datasets/test_NUSEF.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from pytest import approx 4 | from scipy.stats import kurtosis, skew 5 | 6 | import pysaliency 7 | from tests.test_external_datasets import _location, entropy 8 | 9 | 10 | @pytest.mark.slow 11 | @pytest.mark.download 12 | def test_NUSEF(location): 13 | real_location = _location(location) 14 | 15 | stimuli, fixations = pysaliency.external_datasets.get_NUSEF_public(location=real_location) 16 | if location is None: 17 | assert isinstance(stimuli, pysaliency.Stimuli) 18 | assert not isinstance(stimuli, pysaliency.FileStimuli) 19 | else: 20 | assert isinstance(stimuli, pysaliency.FileStimuli) 21 | assert location.join('NUSEF_public/stimuli.hdf5').check() 22 | assert location.join('NUSEF_public/fixations.hdf5').check() 23 | assert location.join('NUSEF_public/src/NUSEF_database.zip').check() 24 | 25 | assert len(stimuli.stimuli) == 429 26 | 27 | assert len(fixations.x) == 66133 28 | 29 | assert np.mean(fixations.x) == approx(461.73823151304873) 30 | assert np.mean(fixations.y) == approx(336.54399742934976) 31 | assert np.mean(fixations.t) == approx(2.0420471776571456) 32 | assert np.mean(fixations.scanpath_history_length) == approx(4.085887529675049) 33 | 34 | assert np.std(fixations.x) == approx(191.71434262715272) 35 | assert np.std(fixations.y) == approx(144.60874197688884) 36 | assert np.std(fixations.t) == approx(1.82140623534086) 37 | assert np.std(fixations.scanpath_history_length) == approx(3.4339653884944963) 38 | 39 | assert kurtosis(fixations.x) == approx(0.29833124844005354) 40 | assert kurtosis(fixations.y) == approx(1.9158192030098018) 41 | assert kurtosis(fixations.t) == approx(5285.812604733467) 42 | assert kurtosis(fixations.scanpath_history_length) == approx(0.8320210638515699) 43 | 44 | assert skew(fixations.x) == approx(0.3994141751115464) 45 | assert skew(fixations.y) == approx(0.7246047287335385) 46 | assert skew(fixations.t) == approx(39.25751334379433) 47 | assert skew(fixations.scanpath_history_length) == approx(0.9874139139443956) 48 | 49 | assert entropy(fixations.n) == approx(8.603204478724775) 50 | assert (fixations.n == 0).sum() == 132 51 | 52 | # not testing this, there are many out-of-stimulus fixations in the dataset 53 | # assert len(fixations) == len(pysaliency.datasets.remove_out_of_stimulus_fixations(stimuli, fixations)) 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /tests/external_datasets/test_PASCAL_S.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from pytest import approx 4 | from scipy.stats import kurtosis, skew 5 | 6 | import pysaliency 7 | import pysaliency.external_datasets 8 | from tests.test_external_datasets import _location, entropy 9 | 10 | 11 | @pytest.mark.slow 12 | @pytest.mark.download 13 | def test_PASCAL_S(location): 14 | real_location = _location(location) 15 | 16 | stimuli, fixations = pysaliency.external_datasets.get_PASCAL_S(location=real_location) 17 | if location is None: 18 | assert isinstance(stimuli, pysaliency.Stimuli) 19 | assert not isinstance(stimuli, pysaliency.FileStimuli) 20 | else: 21 | assert isinstance(stimuli, pysaliency.FileStimuli) 22 | assert 
location.join('PASCAL-S/stimuli.hdf5').check() 23 | assert location.join('PASCAL-S/fixations.hdf5').check() 24 | 25 | assert len(stimuli.stimuli) == 850 26 | 27 | assert len(fixations.x) == 40314 28 | 29 | assert np.mean(fixations.x) == approx(240.72756362553952) 30 | assert np.mean(fixations.y) == approx(194.85756809048965) 31 | assert np.mean(fixations.t) == approx(2.7856823932132757) 32 | assert np.mean(fixations.scanpath_history_length) == approx(2.7856823932132757) 33 | 34 | assert np.std(fixations.x) == approx(79.57401169717699) 35 | assert np.std(fixations.y) == approx(65.21296890260112) 36 | assert np.std(fixations.t) == approx(2.1191752645988675) 37 | assert np.std(fixations.scanpath_history_length) == approx(2.1191752645988675) 38 | 39 | assert kurtosis(fixations.x) == approx(0.0009226786675387011) 40 | assert kurtosis(fixations.y) == approx(1.1907544566979986) 41 | assert kurtosis(fixations.t) == approx(-0.540943536495714) 42 | assert kurtosis(fixations.scanpath_history_length) == approx(-0.540943536495714) 43 | 44 | assert skew(fixations.x) == approx(0.2112334873314548) 45 | assert skew(fixations.y) == approx(0.7208733522533084) 46 | assert skew(fixations.t) == approx(0.4800678710338635) 47 | assert skew(fixations.scanpath_history_length) == approx(0.4800678710338635) 48 | 49 | assert entropy(fixations.n) == approx(9.711222735065062) 50 | assert (fixations.n == 0).sum() == 35 51 | 52 | assert len(fixations) == len(pysaliency.datasets.remove_out_of_stimulus_fixations(stimuli, fixations)) -------------------------------------------------------------------------------- /tests/external_models/AIM_color_stimulus.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-k/pysaliency/3f153e49bb636f774b9f49ccb5016b22f70eb52d/tests/external_models/AIM_color_stimulus.npy -------------------------------------------------------------------------------- /tests/external_models/AIM_grayscale_stimulus.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-k/pysaliency/3f153e49bb636f774b9f49ccb5016b22f70eb52d/tests/external_models/AIM_grayscale_stimulus.npy -------------------------------------------------------------------------------- /tests/external_models/ContextAwareSaliency_color_stimulus.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-k/pysaliency/3f153e49bb636f774b9f49ccb5016b22f70eb52d/tests/external_models/ContextAwareSaliency_color_stimulus.npy -------------------------------------------------------------------------------- /tests/external_models/ContextAwareSaliency_grayscale_stimulus.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-k/pysaliency/3f153e49bb636f774b9f49ccb5016b22f70eb52d/tests/external_models/ContextAwareSaliency_grayscale_stimulus.npy -------------------------------------------------------------------------------- /tests/external_models/CovSal_color_stimulus.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-k/pysaliency/3f153e49bb636f774b9f49ccb5016b22f70eb52d/tests/external_models/CovSal_color_stimulus.npy -------------------------------------------------------------------------------- /tests/external_models/CovSal_grayscale_stimulus.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-k/pysaliency/3f153e49bb636f774b9f49ccb5016b22f70eb52d/tests/external_models/CovSal_grayscale_stimulus.npy -------------------------------------------------------------------------------- /tests/external_models/GBVSIttiKoch_color_stimulus.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-k/pysaliency/3f153e49bb636f774b9f49ccb5016b22f70eb52d/tests/external_models/GBVSIttiKoch_color_stimulus.npy -------------------------------------------------------------------------------- /tests/external_models/GBVSIttiKoch_grayscale_stimulus.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-k/pysaliency/3f153e49bb636f774b9f49ccb5016b22f70eb52d/tests/external_models/GBVSIttiKoch_grayscale_stimulus.npy -------------------------------------------------------------------------------- /tests/external_models/GBVS_color_stimulus.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-k/pysaliency/3f153e49bb636f774b9f49ccb5016b22f70eb52d/tests/external_models/GBVS_color_stimulus.npy -------------------------------------------------------------------------------- /tests/external_models/GBVS_grayscale_stimulus.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-k/pysaliency/3f153e49bb636f774b9f49ccb5016b22f70eb52d/tests/external_models/GBVS_grayscale_stimulus.npy -------------------------------------------------------------------------------- /tests/external_models/IttiKoch_color_stimulus.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-k/pysaliency/3f153e49bb636f774b9f49ccb5016b22f70eb52d/tests/external_models/IttiKoch_color_stimulus.npy -------------------------------------------------------------------------------- /tests/external_models/IttiKoch_grayscale_stimulus.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-k/pysaliency/3f153e49bb636f774b9f49ccb5016b22f70eb52d/tests/external_models/IttiKoch_grayscale_stimulus.npy -------------------------------------------------------------------------------- /tests/external_models/Judd_color_stimulus.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-k/pysaliency/3f153e49bb636f774b9f49ccb5016b22f70eb52d/tests/external_models/Judd_color_stimulus.npy -------------------------------------------------------------------------------- /tests/external_models/Judd_grayscale_stimulus.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-k/pysaliency/3f153e49bb636f774b9f49ccb5016b22f70eb52d/tests/external_models/Judd_grayscale_stimulus.npy -------------------------------------------------------------------------------- /tests/external_models/RARE2007_color_stimulus.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-k/pysaliency/3f153e49bb636f774b9f49ccb5016b22f70eb52d/tests/external_models/RARE2007_color_stimulus.npy 
-------------------------------------------------------------------------------- /tests/external_models/RARE2007_grayscale_stimulus.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-k/pysaliency/3f153e49bb636f774b9f49ccb5016b22f70eb52d/tests/external_models/RARE2007_grayscale_stimulus.npy -------------------------------------------------------------------------------- /tests/external_models/RARE2012_color_stimulus.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-k/pysaliency/3f153e49bb636f774b9f49ccb5016b22f70eb52d/tests/external_models/RARE2012_color_stimulus.npy -------------------------------------------------------------------------------- /tests/external_models/RARE2012_grayscale_stimulus.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-k/pysaliency/3f153e49bb636f774b9f49ccb5016b22f70eb52d/tests/external_models/RARE2012_grayscale_stimulus.npy -------------------------------------------------------------------------------- /tests/external_models/SUN_color_stimulus.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-k/pysaliency/3f153e49bb636f774b9f49ccb5016b22f70eb52d/tests/external_models/SUN_color_stimulus.npy -------------------------------------------------------------------------------- /tests/external_models/SUN_grayscale_stimulus.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-k/pysaliency/3f153e49bb636f774b9f49ccb5016b22f70eb52d/tests/external_models/SUN_grayscale_stimulus.npy -------------------------------------------------------------------------------- /tests/external_models/color_stimulus.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-k/pysaliency/3f153e49bb636f774b9f49ccb5016b22f70eb52d/tests/external_models/color_stimulus.npy -------------------------------------------------------------------------------- /tests/external_models/grayscale_stimulus.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-k/pysaliency/3f153e49bb636f774b9f49ccb5016b22f70eb52d/tests/external_models/grayscale_stimulus.npy -------------------------------------------------------------------------------- /tests/external_models/test_deepgaze.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | import pysaliency 7 | from pysaliency.external_models.deepgaze import DeepGazeI, DeepGazeIIE 8 | 9 | 10 | @pytest.fixture(scope='module') 11 | def color_stimulus(): 12 | return np.load(os.path.join('tests', 'external_models', 'color_stimulus.npy')) 13 | 14 | 15 | @pytest.fixture(scope='module') 16 | def grayscale_stimulus(): 17 | return np.load(os.path.join('tests', 'external_models', 'grayscale_stimulus.npy')) 18 | 19 | 20 | @pytest.fixture 21 | def stimuli(color_stimulus, grayscale_stimulus): 22 | return pysaliency.Stimuli([color_stimulus, grayscale_stimulus]) 23 | 24 | 25 | @pytest.fixture 26 | def fixations(): 27 | return pysaliency.FixationTrains.from_fixation_trains( 28 | [[700, 730], [430, 450]], 29 | [[300, 300], [500, 500]], 30 | [[0, 1], [0, 1]], 31 | ns=[0, 
1], 32 | subjects=[0, 0], 33 | ) 34 | 35 | 36 | @pytest.mark.download 37 | def test_deepgaze1(stimuli, fixations): 38 | model = DeepGazeI(centerbias_model=pysaliency.UniformModel(), device='cpu') 39 | 40 | ig = model.information_gain(stimuli, fixations) 41 | 42 | np.testing.assert_allclose(ig, 0.9455161648442227, rtol=5e-6) 43 | 44 | @pytest.mark.download 45 | def test_deepgaze2e(stimuli, fixations): 46 | model = DeepGazeIIE(centerbias_model=pysaliency.UniformModel(), device='cpu') 47 | 48 | ig = model.information_gain(stimuli, fixations) 49 | 50 | np.testing.assert_allclose(ig, 3.918556860669079) -------------------------------------------------------------------------------- /tests/test_baseline_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | from collections import OrderedDict 4 | 5 | import numpy as np 6 | import pytest 7 | 8 | import pysaliency 9 | from pysaliency.baseline_utils import ( 10 | CrossvalMultipleRegularizations, 11 | GeneralMixtureKernelDensityEstimator, 12 | KDEGoldModel, 13 | MixtureKernelDensityEstimator, 14 | ScikitLearnImageCrossValidationGenerator, 15 | fill_fixation_map, 16 | ) 17 | 18 | 19 | @pytest.fixture 20 | def scanpath_fixations(): 21 | xs_trains = [ 22 | [15, 20, 25], 23 | [10, 30], 24 | [30, 20, 10]] 25 | ys_trains = [ 26 | [13, 21, 10], 27 | [15, 35], 28 | [22, 5, 18]] 29 | ts_trains = [ 30 | [0, 200, 600], 31 | [100, 400], 32 | [50, 500, 900]] 33 | ns = [0, 0, 1] 34 | subjects = [0, 1, 1] 35 | return pysaliency.ScanpathFixations(pysaliency.Scanpaths(xs=xs_trains, ys=ys_trains, ts=ts_trains, n=ns, subject=subjects)) 36 | 37 | 38 | @pytest.fixture 39 | def stimuli(): 40 | return pysaliency.Stimuli([np.random.randn(40, 40, 3), 41 | np.random.randn(40, 40, 3)]) 42 | 43 | 44 | def test_fixation_map(): 45 | fixations = np.array([ 46 | [0, 0], 47 | [1, 1], 48 | [1, 1], 49 | [1, 2], 50 | [1, 2], 51 | [2, 1]]) 52 | 53 | fixation_map = np.zeros((3, 3)) 54 | fill_fixation_map(fixation_map, fixations) 55 | 56 | np.testing.assert_allclose(fixation_map, np.array([ 57 | [1, 0, 0], 58 | [0, 2, 2], 59 | [0, 1, 0]])) 60 | 61 | 62 | def test_kde_gold_model(stimuli, scanpath_fixations): 63 | bandwidth = 0.1 64 | kde_gold_model = KDEGoldModel(stimuli, scanpath_fixations, bandwidth=bandwidth) 65 | spaced_kde_gold_model = KDEGoldModel(stimuli, scanpath_fixations, bandwidth=bandwidth, grid_spacing=2) 66 | 67 | full_log_density = kde_gold_model.log_density(stimuli[0]) 68 | spaced_log_density = spaced_kde_gold_model.log_density(stimuli[0]) 69 | 70 | kl_div1 = np.sum(np.exp(full_log_density) * (full_log_density - spaced_log_density)) / np.log(2) 71 | kl_div2 = np.sum(np.exp(spaced_log_density) * (spaced_log_density - full_log_density)) / np.log(2) 72 | 73 | assert kl_div1 < 0.002 74 | assert kl_div2 < 0.002 75 | 76 | full_ll = kde_gold_model.information_gain(stimuli, scanpath_fixations, average='image') 77 | spaced_ll = spaced_kde_gold_model.information_gain(stimuli, scanpath_fixations, average='image') 78 | print(full_ll, spaced_ll) 79 | np.testing.assert_allclose(full_ll, 2.1912009255501252) 80 | np.testing.assert_allclose(spaced_ll, 2.191055750664578) 81 | 82 | 83 | def test_general_mixture_kernel_density_estimator(): 84 | # Test initialization 85 | estimator = GeneralMixtureKernelDensityEstimator(bandwidth=1.0, regularizations=[0.2, 0.1], regularizing_log_likelihoods=[[-1, 0.0], [-0.1, -10.0], [-10, -0.1]]) 86 | assert estimator.bandwidth == 
1.0 87 | assert np.allclose(estimator.regularizations, [0.2, 0.1]) 88 | assert np.allclose(estimator.regularizing_log_likelihoods, [[-1, 0.0], [-0.1, -10.0], [-10, -0.1]]) 89 | 90 | # Test setup 91 | estimator.setup() 92 | assert estimator.kde is not None 93 | assert estimator.kde_constant is not None 94 | assert estimator.regularization_constants is not None 95 | 96 | # Test fit 97 | X = np.array([[0, 0, 0], [1, 1, 1], [2, 2, 2]]) 98 | estimator.fit(X) 99 | assert estimator.kde is not None 100 | 101 | # Test score_samples 102 | X = np.array([[0, 0, 0], [1, 1, 1], [2, 2, 2]]) 103 | logliks = estimator.score_samples(X) 104 | assert logliks.shape == (3,) 105 | np.testing.assert_allclose(logliks, [-1.49141561, -1.40473767, -1.95213405]) 106 | 107 | # Test score 108 | X = np.array([[0, 0, 0], [1, 1, 1], [2, 2, 2]]) 109 | score = estimator.score(X) 110 | assert isinstance(score, float) 111 | 112 | 113 | def test_mixture_kernel_density_estimator(): 114 | # Test initialization 115 | estimator = MixtureKernelDensityEstimator(bandwidth=1.0, regularization=1.0e-5, regularizing_log_likelihoods=[-0.3, -0.2, -0.1]) 116 | assert estimator.bandwidth == 1.0 117 | assert estimator.regularization == 1.0e-5 118 | 119 | # Test setup 120 | estimator.setup() 121 | assert estimator.kde is not None 122 | assert estimator.kde_constant is not None 123 | assert estimator.uniform_constant is not None 124 | 125 | # Test fit 126 | X = np.array([[0, 0, 0], [1, 1, 1], [2, 2, 2]]) 127 | estimator.fit(X) 128 | assert estimator.kde is not None 129 | 130 | # Test score_samples 131 | X = np.array([[0, 0.2, 0], [0.3, 1, 1], [1, 1, 2]]) 132 | logliks = estimator.score_samples(X) 133 | assert logliks.shape == (3,) 134 | np.testing.assert_allclose(logliks, [-2.56662505, -2.5272495, -2.38495638]) 135 | 136 | # Test score 137 | X = np.array([[0, 0, 0], [1, 1, 1], [2, 2, 2]]) 138 | score = estimator.score(X) 139 | assert isinstance(score, float) 140 | 141 | 142 | def test_crossval_multiple_regularizations(stimuli, scanpath_fixations): 143 | # Test initialization 144 | regularization_models = OrderedDict([('model1', pysaliency.UniformModel()), ('model2', pysaliency.models.GaussianModel())]) 145 | crossvalidation = ScikitLearnImageCrossValidationGenerator(stimuli, scanpath_fixations) 146 | estimator = CrossvalMultipleRegularizations(stimuli, scanpath_fixations, regularization_models, crossvalidation) 147 | assert estimator.cv is crossvalidation 148 | assert estimator.mean_area is not None 149 | assert estimator.X is not None 150 | assert estimator.regularization_log_likelihoods is not None 151 | 152 | # Test score 153 | log_bandwidth = 0.1 154 | log_regularizations = [0.1, 0.2] 155 | 156 | score = estimator.score(log_bandwidth, *log_regularizations) 157 | assert isinstance(score, float) 158 | np.testing.assert_allclose(score, -1.4673831679692528e-10) -------------------------------------------------------------------------------- /tests/test_crossvalidation.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | import numpy as np 4 | import pytest 5 | from sklearn.model_selection import cross_val_score 6 | 7 | import pysaliency 8 | from pysaliency.baseline_utils import RegularizedKernelDensityEstimator, ScikitLearnImageCrossValidationGenerator, ScikitLearnImageSubjectCrossValidationGenerator, fixations_to_scikit_learn 9 | 10 | 11 | class ConstantSaliencyModel(pysaliency.Model): 12 | def _log_density(self, 
stimulus): 13 | return np.zeros((stimulus.shape[0], stimulus.shape[1])) - np.log(stimulus.shape[0]) - np.log(stimulus.shape[1]) 14 | 15 | 16 | class GaussianSaliencyModel(pysaliency.Model): 17 | def _log_density(self, stimulus): 18 | height = stimulus.shape[0] 19 | width = stimulus.shape[1] 20 | YS, XS = np.mgrid[:height, :width] 21 | r_squared = (XS-0.5*width)**2 + (YS-0.5*height)**2 22 | size = np.sqrt(width**2+height**2) 23 | values = np.ones((stimulus.shape[0], stimulus.shape[1]))*np.exp(-0.5*(r_squared/size)) 24 | density = values / values.sum() 25 | return np.log(density) 26 | 27 | 28 | @pytest.fixture 29 | def scanpath_fixations(): 30 | xs_trains = [ 31 | [0, 1, 2], 32 | [2, 2], 33 | [1, 5, 3], 34 | [10]] 35 | ys_trains = [ 36 | [10, 11, 12], 37 | [12, 12], 38 | [21, 25, 33], 39 | [11]] 40 | ts_trains = [ 41 | [0, 200, 600], 42 | [100, 400], 43 | [50, 500, 900], 44 | [100]] 45 | ns = [0, 0, 1, 2] 46 | subjects = [0, 1, 1, 0] 47 | return pysaliency.ScanpathFixations(pysaliency.Scanpaths(xs=xs_trains, ys=ys_trains, ts=ts_trains, n=ns, subject=subjects)) 48 | 49 | 50 | @pytest.fixture 51 | def stimuli(): 52 | return pysaliency.Stimuli([np.random.randn(40, 40, 3), 53 | np.random.randn(40, 40, 3), 54 | np.random.randn(40, 40, 3)]) 55 | 56 | 57 | def _unpack_crossval(cv): 58 | for train_inds, test_inds in cv: 59 | yield list(train_inds), list(test_inds) 60 | 61 | 62 | def unpack_crossval(cv): 63 | return list(_unpack_crossval(cv)) 64 | 65 | 66 | def test_image_crossvalidation(stimuli, scanpath_fixations): 67 | gsmm = GaussianSaliencyModel() 68 | 69 | cv = ScikitLearnImageCrossValidationGenerator(stimuli, scanpath_fixations) 70 | 71 | assert unpack_crossval(cv) == [ 72 | ([False, False, False, False, False, True, True, True, True], 73 | [True, True, True, True, True, False, False, False, False]), 74 | ([True, True, True, True, True, False, False, False, True], 75 | [False, False, False, False, False, True, True, True, False]), 76 | ([True, True, True, True, True, True, True, True, False], 77 | [False, False, False, False, False, False, False, False, True]) 78 | ] 79 | 80 | X = fixations_to_scikit_learn(scanpath_fixations, normalize=stimuli, add_shape=True) 81 | 82 | assert cross_val_score( 83 | RegularizedKernelDensityEstimator(bandwidth=0.1, regularization=0.1), 84 | X, 85 | cv=cv, 86 | verbose=0).sum() 87 | 88 | 89 | def test_image_subject_crossvalidation(stimuli, scanpath_fixations): 90 | gsmm = GaussianSaliencyModel() 91 | 92 | cv = ScikitLearnImageSubjectCrossValidationGenerator(stimuli, scanpath_fixations) 93 | 94 | assert unpack_crossval(cv) == [ 95 | ([3, 4], [0, 1, 2]), 96 | ([0, 1, 2], [3, 4]) 97 | ] 98 | 99 | X = fixations_to_scikit_learn(scanpath_fixations, normalize=stimuli, add_shape=True) 100 | 101 | assert cross_val_score( 102 | RegularizedKernelDensityEstimator(bandwidth=0.1, regularization=0.1), 103 | X, 104 | cv=cv, 105 | verbose=0).sum() 106 | -------------------------------------------------------------------------------- /tests/test_helpers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, print_function, division 2 | 3 | import unittest 4 | import pickle 5 | import os.path 6 | import shutil 7 | import filecmp 8 | 9 | def assert_equal(a, b): 10 | assert a == b 11 | 12 | 13 | class TestWithData(unittest.TestCase): 14 | data_path = 'test_data' 15 | 16 | def setUp(self): 17 | if os.path.isdir(self.data_path): 18 | shutil.rmtree(self.data_path) 19 | if not os.path.exists(self.data_path): 20 | 
os.makedirs(self.data_path) 21 | 22 | def tearDown(self): 23 | shutil.rmtree(self.data_path) 24 | 25 | def pickle_and_reload(self, data, pickler=pickle): 26 | filename = os.path.join(self.data_path, 'object.pydat') 27 | 28 | with open(filename, 'wb') as f: 29 | pickler.dump(data, f) 30 | 31 | with open(filename, 'rb') as f: 32 | new_data = pickler.load(f) 33 | 34 | return new_data 35 | 36 | 37 | def check_dircmp(dircmp): 38 | assert_equal(dircmp.left_only, []) 39 | assert_equal(dircmp.right_only, []) 40 | assert_equal(dircmp.diff_files, []) 41 | assert_equal(dircmp.funny_files, []) 42 | for sub_dcmp in dircmp.subdirs.values(): 43 | check_dircmp(sub_dcmp)  # recurse into each subdirectory instead of re-checking the same dircmp 44 | 45 | 46 | def assertDirsEqual(dir1, dir2, ignore=[]): 47 | dircmp = filecmp.dircmp(dir1, dir2, ignore=ignore) 48 | check_dircmp(dircmp) 49 | -------------------------------------------------------------------------------- /tests/test_metric_optimization.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | import pytest 4 | import numpy as np 5 | 6 | import pysaliency 7 | 8 | 9 | class GaussianSaliencyModel(pysaliency.Model): 10 | def _log_density(self, stimulus): 11 | height = stimulus.shape[0] 12 | width = stimulus.shape[1] 13 | YS, XS = np.mgrid[:height, :width] 14 | r_squared = (XS-0.5*width)**2 + (YS-0.5*height)**2 15 | size = np.sqrt(width**2+height**2) 16 | values = np.ones((stimulus.shape[0], stimulus.shape[1]))*np.exp(-0.5*(r_squared/size)) 17 | density = values / values.sum() 18 | return np.log(density) 19 | 20 | 21 | @pytest.fixture 22 | def stimuli(): 23 | return pysaliency.Stimuli([np.random.randn(40, 40, 3), 24 | np.random.randn(40, 40, 3)]) 25 | 26 | 27 | def test_sim_saliency_map(stimuli): 28 | gsmm = GaussianSaliencyModel() 29 | 30 | sim_model = pysaliency.SIMSaliencyMapModel(gsmm, kernel_size=2, max_iter=100, initial_learning_rate=1e-6, 31 | learning_rate_decay_scheme='validation_loss') 32 | 33 | smap = sim_model.saliency_map(stimuli[0]) 34 | assert smap.shape == (40, 40) 35 | -------------------------------------------------------------------------------- /tests/test_metric_optimization_tf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | 5 | # from pysaliency.metric_optimization_tf import maximize_expected_sim 6 | 7 | 8 | @pytest.mark.skip("tensorflow <2.0 not available for new python versions, need to upgrade to tensorflow 2 in pysaliency") 9 | def test_maximize_expected_sim_decay_1overk(): 10 | density = np.ones((20, 20)) 11 | density[6:17, 8:12] = 20 12 | density[2:4, 18:18] = 30 13 | density /= density.sum() 14 | log_density = np.log(density) 15 | 16 | saliency_map, score = maximize_expected_sim( 17 | log_density=log_density, 18 | kernel_size=1, 19 | train_samples_per_epoch=1000, 20 | val_samples=1000, 21 | max_iter=100 22 | ) 23 | 24 | np.testing.assert_allclose(score, -0.8202789932489393, rtol=5e-7) # need bigger tolerance to handle differences between CPU and GPU 25 | 26 | 27 | @pytest.mark.skip("tensorflow <2.0 not available for new python versions, need to upgrade to tensorflow 2 in pysaliency") 28 | def test_maximize_expected_sim_decay_on_plateau(): 29 | density = np.ones((20, 20)) 30 | density[6:17, 8:12] = 20 31 | density[2:4, 18:18] = 30 32 | density /= density.sum() 33 | log_density = np.log(density) 34 | 35 | saliency_map, score = maximize_expected_sim( 36 | log_density=log_density, 37 |
kernel_size=1, 38 | train_samples_per_epoch=1000, 39 | val_samples=1000, 40 | max_iter=100, 41 | backlook=1, 42 | min_iter=10, 43 | learning_rate_decay_scheme='validation_loss', 44 | ) 45 | 46 | np.testing.assert_allclose(score, -0.8203513294458387, rtol=5e-7) # need bigger tolerance to handle differences between CPU and GPU 47 | -------------------------------------------------------------------------------- /tests/test_metric_optimization_torch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pysaliency.metric_optimization_torch import maximize_expected_sim 4 | 5 | 6 | def test_maximize_expected_sim_decay_1overk(): 7 | density = np.ones((20, 20)) 8 | density[6:17, 8:12] = 20 9 | density[2:4, 18:18] = 30 10 | density /= density.sum() 11 | log_density = np.log(density) 12 | 13 | saliency_map, score = maximize_expected_sim( 14 | log_density=log_density, 15 | kernel_size=1, 16 | train_samples_per_epoch=1000, 17 | val_samples=1000, 18 | max_iter=100 19 | ) 20 | 21 | print(score) 22 | # We need a quite big tolerance in this test. Apparently there are 23 | # substantial differences between different systems, I'm not sure why. 24 | np.testing.assert_allclose(score, -0.8204902112483976, rtol=5e-4) 25 | 26 | 27 | def test_maximize_expected_sim_decay_on_plateau(): 28 | density = np.ones((20, 20)) 29 | density[6:17, 8:12] = 20 30 | density[2:4, 18:18] = 30 31 | density /= density.sum() 32 | log_density = np.log(density) 33 | 34 | saliency_map, score = maximize_expected_sim( 35 | log_density=log_density, 36 | kernel_size=1, 37 | train_samples_per_epoch=1000, 38 | val_samples=1000, 39 | max_iter=100, 40 | backlook=1, 41 | min_iter=10, 42 | learning_rate_decay_scheme='validation_loss', 43 | ) 44 | 45 | print(score) 46 | np.testing.assert_allclose(score, -0.8205618500709532, rtol=5e-4) # need bigger tolerance to handle differences between CPU and GPU 47 | -------------------------------------------------------------------------------- /tests/test_numba_utils.py: -------------------------------------------------------------------------------- 1 | from hypothesis import given, strategies as st, assume, settings 2 | import numpy as np 3 | 4 | from pysaliency.numba_utils import auc_for_one_positive, general_roc_numba, general_rocs_per_positive_numba 5 | from pysaliency.roc_cython import general_roc, general_rocs_per_positive 6 | 7 | 8 | def test_auc_for_one_positive(): 9 | assert auc_for_one_positive(1, [0, 2]) == 0.5 10 | assert auc_for_one_positive(1, [1]) == 0.5 11 | assert auc_for_one_positive(3, [0]) == 1.0 12 | assert auc_for_one_positive(0, [3]) == 0.0 13 | 14 | 15 | @given(st.lists(st.floats(allow_nan=False, allow_infinity=False), min_size=1), st.floats(allow_nan=False, allow_infinity=False)) 16 | def test_simple_auc_hypothesis(negatives, positive): 17 | old_auc, _, _ = general_roc(np.array([positive]), np.array(negatives)) 18 | new_auc = auc_for_one_positive(positive, np.array(negatives)) 19 | np.testing.assert_allclose(old_auc, new_auc) 20 | 21 | 22 | @settings(deadline=None) #to remove time limit from a test 23 | @given(st.lists(st.floats(allow_infinity=False,allow_nan=False),min_size=1), st.lists(st.floats(allow_infinity=False,allow_nan=False),min_size=1)) 24 | def test_numba_auc_test1(positives,negatives): 25 | positives = np.array(positives) 26 | negatives = np.array(negatives) 27 | numba_output = general_roc_numba(positives,negatives) 28 | cython_output = general_roc(positives,negatives) 29 | assert 
np.isclose(numba_output[0],cython_output[0]) 30 | assert (numba_output[1] == cython_output[1]).all() 31 | assert (numba_output[2] == cython_output[2]).all() 32 | 33 | 34 | @settings(deadline=None) 35 | @given(st.lists(st.floats(allow_infinity=False,allow_nan=False),min_size=1), st.floats(allow_infinity=False,allow_nan=False)) 36 | def test_numba_auc_test2(positives,temp_variable): 37 | positives = np.array(positives) 38 | negatives = positives+temp_variable 39 | numba_output = general_roc_numba(positives,negatives) 40 | cython_output = general_roc(positives,negatives) 41 | assert np.isclose(numba_output[0],cython_output[0]) 42 | assert (numba_output[1] == cython_output[1]).all() 43 | assert (numba_output[2] == cython_output[2]).all() 44 | 45 | 46 | @settings(deadline=None) 47 | @given(st.lists(st.floats(allow_infinity=False,allow_nan=False),min_size=1), st.lists(st.floats(allow_infinity=False,allow_nan=False),min_size=1)) 48 | def test_numba_rocs_per_positive(positives,negatives): 49 | positives = np.array(positives) 50 | negatives = np.array(negatives) 51 | numba_output = general_rocs_per_positive_numba(positives,negatives) 52 | cython_output = general_rocs_per_positive(positives,negatives) 53 | assert (numba_output == cython_output).all() -------------------------------------------------------------------------------- /tests/test_quilt.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division, unicode_literals, absolute_import 2 | 3 | import os 4 | import shutil 5 | import filecmp 6 | import unittest 7 | 8 | from pysaliency.quilt import PatchFile, QuiltSeries 9 | 10 | from test_helpers import TestWithData, assertDirsEqual 11 | 12 | 13 | class TestPatchFile(TestWithData): 14 | def test_parsing(self): 15 | p = open('tests/test_quilt/patches/add_numbers.diff').read() 16 | patch = PatchFile(p) 17 | self.assertEqual(len(patch.diffs), 1) 18 | diff = patch.diffs[0] 19 | self.assertEqual(len(diff.hunks), 1) 20 | self.assertEqual(diff.source_filename, 'source.txt') 21 | self.assertEqual(diff.target_filename, 'source.txt') 22 | 23 | hunk = diff.hunks[0] 24 | self.assertEqual(hunk.source_start, 3) 25 | self.assertEqual(hunk.source_length, 6) 26 | 27 | self.assertEqual(hunk.target_start, 3) 28 | self.assertEqual(hunk.target_length, 8) 29 | 30 | def test_apply(self): 31 | location = os.path.join(self.data_path, 'patching') 32 | shutil.copytree('tests/test_quilt/source', location) 33 | p = open('tests/test_quilt/patches/add_numbers.diff').read() 34 | patch = PatchFile(p) 35 | 36 | patch.apply(location) 37 | self.assertTrue(filecmp.cmp(os.path.join(location, 'source.txt'), 38 | 'tests/test_quilt/target/source.txt', 39 | shallow=False)) 40 | 41 | 42 | class TestSeries(TestWithData): 43 | def test_parsing(self): 44 | series = QuiltSeries(os.path.join('tests', 'test_quilt', 'patches')) 45 | self.assertEqual(len(series.patches), 1) 46 | 47 | def test_apply(self): 48 | location = os.path.join(self.data_path, 'patching') 49 | shutil.copytree('tests/test_quilt/source', location) 50 | series = QuiltSeries(os.path.join('tests', 'test_quilt', 'patches')) 51 | series.apply(location) 52 | assertDirsEqual(location, os.path.join('tests', 'test_quilt', 'target'), 53 | ignore=['patches', '.pc']) 54 | 55 | if __name__ == '__main__': 56 | unittest.main() 57 | -------------------------------------------------------------------------------- /tests/test_quilt/.pc/.quilt_patches: 
-------------------------------------------------------------------------------- 1 | patches 2 | -------------------------------------------------------------------------------- /tests/test_quilt/.pc/.quilt_series: -------------------------------------------------------------------------------- 1 | series 2 | -------------------------------------------------------------------------------- /tests/test_quilt/.pc/.version: -------------------------------------------------------------------------------- 1 | 2 2 | -------------------------------------------------------------------------------- /tests/test_quilt/.pc/add_numbers.diff/.timestamp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthias-k/pysaliency/3f153e49bb636f774b9f49ccb5016b22f70eb52d/tests/test_quilt/.pc/add_numbers.diff/.timestamp -------------------------------------------------------------------------------- /tests/test_quilt/.pc/add_numbers.diff/source.txt: -------------------------------------------------------------------------------- 1 | foo 2 | bar 3 | baz 4 | abcefg 5 | blub 6 | 3.14156 7 | 2.718281828459045 8 | 0.7071 9 | 1.4142135623730951 10 | -------------------------------------------------------------------------------- /tests/test_quilt/.pc/applied-patches: -------------------------------------------------------------------------------- 1 | add_numbers.diff 2 | -------------------------------------------------------------------------------- /tests/test_quilt/patches/add_numbers.diff: -------------------------------------------------------------------------------- 1 | Index: test_quilt/source.txt 2 | =================================================================== 3 | --- test_quilt.orig/source.txt 2014-12-04 23:15:40.372473852 +0100 4 | +++ test_quilt/source.txt 2014-12-04 23:16:04.820103862 +0100 5 | @@ -3,6 +3,8 @@ 6 | baz 7 | abcefg 8 | blub 9 | +42 10 | +23 11 | 3.14156 12 | 2.718281828459045 13 | 0.7071 14 | -------------------------------------------------------------------------------- /tests/test_quilt/patches/series: -------------------------------------------------------------------------------- 1 | add_numbers.diff 2 | -------------------------------------------------------------------------------- /tests/test_quilt/source.txt: -------------------------------------------------------------------------------- 1 | foo 2 | bar 3 | baz 4 | abcefg 5 | blub 6 | 42 7 | 23 8 | 3.14156 9 | 2.718281828459045 10 | 0.7071 11 | 1.4142135623730951 12 | -------------------------------------------------------------------------------- /tests/test_quilt/source/.pc/.quilt_patches: -------------------------------------------------------------------------------- 1 | patches 2 | -------------------------------------------------------------------------------- /tests/test_quilt/source/.pc/.quilt_series: -------------------------------------------------------------------------------- 1 | series 2 | -------------------------------------------------------------------------------- /tests/test_quilt/source/.pc/.version: -------------------------------------------------------------------------------- 1 | 2 2 | -------------------------------------------------------------------------------- /tests/test_quilt/source/patches: -------------------------------------------------------------------------------- 1 | ../patches -------------------------------------------------------------------------------- /tests/test_quilt/source/source.txt: 
-------------------------------------------------------------------------------- 1 | foo 2 | bar 3 | baz 4 | abcefg 5 | blub 6 | 3.14156 7 | 2.718281828459045 8 | 0.7071 9 | 1.4142135623730951 10 | -------------------------------------------------------------------------------- /tests/test_quilt/target/source.txt: -------------------------------------------------------------------------------- 1 | foo 2 | bar 3 | baz 4 | abcefg 5 | blub 6 | 42 7 | 23 8 | 3.14156 9 | 2.718281828459045 10 | 0.7071 11 | 1.4142135623730951 12 | -------------------------------------------------------------------------------- /tests/test_saliency_map_conversion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | import pysaliency 5 | from pysaliency import optimize_for_information_gain 6 | from pysaliency.models import SaliencyMapNormalizingModel 7 | 8 | 9 | @pytest.fixture 10 | def stimuli(): 11 | return pysaliency.Stimuli([np.random.randint(0, 255, size=(25, 30, 3)) for i in range(50)]) 12 | 13 | 14 | @pytest.fixture 15 | def saliency_model(): 16 | return pysaliency.GaussianSaliencyMapModel(center_x=0.15, center_y=0.85, width=0.2) 17 | 18 | 19 | @pytest.fixture 20 | def transformed_saliency_model(saliency_model): 21 | return pysaliency.saliency_map_models.LambdaSaliencyMapModel( 22 | [saliency_model], 23 | fn=lambda smaps: np.sqrt(smaps[0]), 24 | ) 25 | 26 | 27 | @pytest.fixture 28 | def probabilistic_model(saliency_model): 29 | blurred_model = pysaliency.BluringSaliencyMapModel(saliency_model, kernel_size=5.0) 30 | centerbias_model = pysaliency.saliency_map_models.LambdaSaliencyMapModel( 31 | [pysaliency.GaussianSaliencyMapModel(width=0.5)], 32 | fn=lambda smaps: 1.0 * smaps[0], 33 | ) 34 | model_with_centerbias = blurred_model * centerbias_model 35 | probabilistic_model = SaliencyMapNormalizingModel(model_with_centerbias) 36 | 37 | return probabilistic_model 38 | 39 | 40 | @pytest.fixture 41 | def fixations(stimuli, probabilistic_model): 42 | return probabilistic_model.sample(stimuli, 1000, rst=np.random.RandomState(seed=42)) 43 | 44 | 45 | @pytest.fixture(params=["torch", "theano"]) 46 | def framework(request): 47 | 48 | if request.param == 'theano': 49 | import theano 50 | old_optimizer = theano.config.optimizer 51 | theano.config.optimizer = 'fast_compile' 52 | 53 | yield request.param 54 | 55 | if request.param == 'theano': 56 | theano.config.optimizer = old_optimizer 57 | 58 | 59 | def test_optimize_for_information_gain(stimuli, fixations, transformed_saliency_model, probabilistic_model, framework): 60 | expected_information_gain = probabilistic_model.information_gain(stimuli, fixations, average='image') 61 | 62 | model1, ret1 = optimize_for_information_gain( 63 | transformed_saliency_model, 64 | stimuli, 65 | fixations, 66 | average='fixations', 67 | verbose=2, 68 | batch_size=1 if framework == 'theano' else 10, 69 | minimize_options={'verbose': 10} if framework == 'torch' else None, 70 | maxiter=50, 71 | blur_radius=2.0, 72 | return_optimization_result=True, 73 | framework=framework, 74 | ) 75 | 76 | reached_information_gain = model1.information_gain(stimuli, fixations, average='image') 77 | 78 | print(expected_information_gain, reached_information_gain) 79 | assert reached_information_gain >= expected_information_gain - 0.01 80 | -------------------------------------------------------------------------------- /tests/test_saliency_map_conversion_theano.py:
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import dill 3 | import pytest 4 | 5 | from pysaliency import GaussianSaliencyMapModel, Stimuli, Fixations 6 | from pysaliency.saliency_map_conversion_theano import SaliencyMapConvertor, optimize_for_information_gain 7 | 8 | 9 | @pytest.mark.theano 10 | @pytest.mark.parametrize("optimize", [ 11 | None, 12 | ['nonlinearity'], 13 | ['nonlinearity', 'centerbias'], 14 | ['nonlinearity', 'alpha', 'centerbias'], 15 | ['centerbias'], 16 | ['blur_radius'], 17 | ['blur_radius', 'nonlinearity'] 18 | ]) 19 | def test_optimize_for_IG(optimize): 20 | # To speed up testing, we disable some optimizations 21 | import theano 22 | old_optimizer = theano.config.optimizer 23 | theano.config.optimizer = 'fast_compile' 24 | 25 | model = GaussianSaliencyMapModel() 26 | stimulus = np.random.randn(100, 100, 3) 27 | stimuli = Stimuli([stimulus]) 28 | 29 | rst = np.random.RandomState(seed=42) 30 | N = 100000 31 | fixations = Fixations.create_without_history( 32 | x=rst.rand(N) * 100, 33 | y=rst.rand(N) * 100, 34 | n=np.zeros(N, dtype=int) 35 | ) 36 | 37 | smc, res = optimize_for_information_gain( 38 | model, 39 | stimuli, 40 | fixations, 41 | optimize=optimize, 42 | blur_radius=3, 43 | verbose=2, 44 | maxiter=10, 45 | return_optimization_result=True) 46 | 47 | theano.config.optimizer = old_optimizer 48 | 49 | assert res.status in [ 50 | 0, # success 51 | 9, # max iter reached 52 | ] 53 | 54 | assert smc 55 | 56 | 57 | @pytest.mark.theano 58 | def test_saliency_map_converter(tmpdir): 59 | import theano 60 | theano.config.floatX = 'float64' 61 | old_optimizer = theano.config.optimizer 62 | theano.config.optimizer = 'fast_compile' 63 | 64 | model = GaussianSaliencyMapModel() 65 | smc = SaliencyMapConvertor(model) 66 | smc.set_params(nonlinearity=np.ones(20), 67 | centerbias=np.ones(12) * 2, 68 | alpha=3, 69 | blur_radius=4, 70 | saliency_min=5, 71 | saliency_max=6) 72 | 73 | theano.config.optimizer = old_optimizer 74 | 75 | pickle_file = tmpdir.join('object.pydat') 76 | with pickle_file.open(mode='wb') as f: 77 | dill.dump(smc, f) 78 | 79 | with pickle_file.open(mode='rb') as f: 80 | smc2 = dill.load(f) 81 | 82 | np.testing.assert_allclose(smc2.saliency_map_processing.nonlinearity_ys.get_value(), np.ones(20)) 83 | np.testing.assert_allclose(smc2.saliency_map_processing.centerbias_ys.get_value(), np.ones(12) * 2) 84 | np.testing.assert_allclose(smc2.saliency_map_processing.alpha.get_value(), 3) 85 | np.testing.assert_allclose(smc2.saliency_map_processing.blur_radius.get_value(), 4) 86 | np.testing.assert_allclose(smc2.saliency_min, 5) 87 | np.testing.assert_allclose(smc2.saliency_max, 6) 88 | -------------------------------------------------------------------------------- /tests/test_saliency_map_conversion_torch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from pysaliency.saliency_map_conversion import optimize_for_information_gain 5 | from pysaliency import Stimuli, Fixations, GaussianSaliencyMapModel 6 | 7 | 8 | @pytest.mark.parametrize("optimize", [ 9 | None, 10 | ['nonlinearity'], 11 | ['nonlinearity', 'centerbias'], 12 | ['nonlinearity', 'alpha', 'centerbias'], 13 | ['centerbias'], 14 | ['blur_radius'], 15 | ['blur_radius', 'nonlinearity'] 16 | ]) 17 | def test_optimize_for_IG(optimize): 18 | model = GaussianSaliencyMapModel() 19 | stimulus = np.random.randn(100, 100, 3) 20 | stimuli = Stimuli([stimulus]) 21 | 22 | rst = 
np.random.RandomState(seed=42) 23 | N = 100000 24 | fixations = Fixations.create_without_history( 25 | x=rst.rand(N) * 100, 26 | y=rst.rand(N) * 100, 27 | n=np.zeros(N, dtype=int) 28 | ) 29 | 30 | smc, res = optimize_for_information_gain( 31 | model, 32 | stimuli, 33 | fixations, 34 | optimize=optimize, 35 | blur_radius=3, 36 | verbose=2, 37 | maxiter=10, 38 | return_optimization_result=True, 39 | framework='torch' 40 | ) 41 | 42 | assert res.status in [ 43 | 0, # success 44 | 9, # max iter reached 45 | ] 46 | 47 | assert smc 48 | -------------------------------------------------------------------------------- /tests/test_sampling.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pysaliency.saliency_map_models import GaussianSaliencyMapModel 4 | from pysaliency.sampling_models import SamplingModelMixin, ScanpathSamplingModelMixin 5 | 6 | 7 | def test_fixation_sampling(): 8 | class SamplingModel(SamplingModelMixin, GaussianSaliencyMapModel): 9 | def sample_fixation(self, stimulus, x_hist, y_hist, t_hist, attributes=None, verbose=False, rst=None): 10 | return x_hist[-1] + 1, y_hist[-1] + 1, t_hist[-1] + 1 11 | 12 | model = SamplingModel() 13 | 14 | xs, ys, ts = model.sample_scanpath(np.zeros((40, 40, 3)), [0], [1], [2], 4) 15 | assert xs == [0, 1, 2, 3, 4] 16 | assert ys == [1, 2, 3, 4, 5] 17 | assert ts == [2, 3, 4, 5, 6] 18 | 19 | 20 | def test_scanpath_sampling(): 21 | class SamplingModel(ScanpathSamplingModelMixin, GaussianSaliencyMapModel): 22 | def sample_scanpath(self, stimulus, x_hist, y_hist, t_hist, samples, attributes=None, verbose=False, rst=None): 23 | return ( 24 | list(x_hist) + [x_hist[-1]] * samples, 25 | list(y_hist) + [y_hist[-1]] * samples, 26 | list(t_hist) + [t_hist[-1]] * samples 27 | ) 28 | 29 | model = SamplingModel() 30 | 31 | x, y, t = model.sample_fixation(np.zeros((40, 40, 3)), [0], [1], [2]) 32 | assert x == 0 33 | assert y == 1 34 | assert t == 2 35 | -------------------------------------------------------------------------------- /tests/test_torch_datasets.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from PIL import Image 4 | import numpy as np 5 | import pytest 6 | 7 | from pysaliency import ( 8 | FileStimuli, 9 | GaussianSaliencyMapModel, 10 | DigitizeMapModel, 11 | SaliencyMapModelFromDirectory, 12 | UniformModel 13 | ) 14 | from pysaliency.torch_datasets import ImageDataset, ImageDatasetSampler, FixationMaskTransform, collate_fn 15 | import torch 16 | 17 | 18 | @pytest.fixture 19 | def stimuli(tmp_path): 20 | filenames = [] 21 | stimuli_directory = tmp_path / 'stimuli' 22 | stimuli_directory.mkdir() 23 | for i in range(50): 24 | image = Image.fromarray(np.random.randint(0, 255, size=(25, 30, 3), dtype=np.uint8)) 25 | filename = stimuli_directory / 'stimulus_{:04d}.png'.format(i) 26 | image.save(filename) 27 | filenames.append(filename) 28 | return FileStimuli(filenames) 29 | 30 | 31 | @pytest.fixture 32 | def fixations(stimuli): 33 | return UniformModel().sample(stimuli, 1000, rst=np.random.RandomState(seed=42)) 34 | 35 | 36 | @pytest.fixture 37 | def saliency_model(): 38 | return GaussianSaliencyMapModel(center_x=0.15, center_y=0.85, width=0.2) 39 | 40 | 41 | @pytest.fixture 42 | def png_saliency_map_model(tmp_path, stimuli, saliency_model): 43 | digitized_model = DigitizeMapModel(saliency_model) 44 | output_path = tmp_path / 'saliency_maps' 45 | output_path.mkdir() 46 | 47 | for filename, stimulus in 
zip(stimuli.filenames, stimuli): 48 | stimulus_name = Path(filename) 49 | output_filename = output_path / f"{stimulus_name.stem}.png" 50 | image = Image.fromarray(digitized_model.saliency_map(stimulus).astype(np.uint8)) 51 | image.save(output_filename) 52 | 53 | return SaliencyMapModelFromDirectory(stimuli, str(output_path)) 54 | 55 | 56 | def test_dataset(stimuli, fixations, png_saliency_map_model): 57 | models_dict = { 58 | 'saliency_map': png_saliency_map_model, 59 | } 60 | 61 | dataset = ImageDataset( 62 | stimuli, 63 | fixations, 64 | models=models_dict, 65 | transform=FixationMaskTransform(), 66 | average='image', 67 | ) 68 | 69 | loader = torch.utils.data.DataLoader( 70 | dataset, 71 | batch_sampler=ImageDatasetSampler(dataset, batch_size=4, shuffle=False), 72 | pin_memory=False, 73 | num_workers=0, # doesn't work for sparse tensors yet. Might work soon. 74 | collate_fn=collate_fn, 75 | ) 76 | 77 | count = 0 78 | for batch in loader: 79 | count += len(batch['saliency_map']) 80 | 81 | assert count == len(stimuli) 82 | -------------------------------------------------------------------------------- /tests/test_torch_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from packaging import version 3 | from scipy.ndimage import gaussian_filter as scipy_filter 4 | import torch 5 | 6 | import hypothesis 7 | from hypothesis import given, strategies as st 8 | from hypothesis.extra import numpy as hypothesis_np 9 | import pytest 10 | 11 | from pysaliency.torch_utils import gaussian_filter, gaussian_filter_1d_new_torch, gaussian_filter_1d_old_torch 12 | 13 | 14 | @pytest.fixture(params=[20.0]) 15 | def sigma(request): 16 | return request.param 17 | 18 | 19 | @pytest.fixture(params=[torch.float64, torch.float32]) 20 | def dtype(request): 21 | return request.param 22 | 23 | 24 | def test_gaussian_filter(sigma, dtype): 25 | #window_radius = int(sigma*4) 26 | test_data = 10*np.ones((4, 1, 100, 100)) 27 | test_data += np.random.randn(4, 1, 100, 100) 28 | 29 | test_tensor = torch.tensor(test_data, dtype=dtype) 30 | 31 | output = gaussian_filter( 32 | tensor=test_tensor, 33 | sigma=torch.tensor(sigma), 34 | truncate=4, 35 | dim=[2, 3], 36 | ).detach().cpu().numpy()[0, 0, :, :] 37 | 38 | scipy_out = scipy_filter(test_data[0, 0], sigma, mode='nearest') 39 | 40 | if dtype == torch.float32: 41 | rtol = 5e-6 42 | else: 43 | rtol = 1e-7 44 | np.testing.assert_allclose(output, scipy_out, rtol=rtol) 45 | 46 | 47 | @pytest.mark.skipif( 48 | version.parse(torch.__version__) < version.parse('1.7') # new code doesn't work because no `torch.movedim` 49 | or version.parse(torch.__version__) >= version.parse('1.11'), # old code doesn't work because torch's conv1d got stricter about input shape 50 | reason="torch either too new for old implementation or too old for new implementation" 51 | ) 52 | @given(hypothesis_np.arrays( 53 | dtype=hypothesis_np.floating_dtypes(sizes=(32, 64), endianness='='), 54 | shape=st.tuples( 55 | st.integers(min_value=1, max_value=100), 56 | st.just(1), 57 | st.integers(min_value=1, max_value=100), 58 | st.integers(min_value=1, max_value=100) 59 | )), 60 | st.floats(allow_nan=False, allow_infinity=False, min_value=0.01, max_value=50), 61 | st.integers(min_value=2, max_value=3), 62 | ) 63 | #@hypothesis.settings(verbosity=hypothesis.Verbosity.verbose) 64 | @hypothesis.settings(deadline=5000) 65 | def test_compare_gaussian_1d_implementations(data, sigma, dim): 66 | data_tensor = torch.tensor(data) 67 | old_data = 
gaussian_filter_1d_old_torch(data_tensor, sigma=sigma, dim=dim).detach().cpu().numpy() 68 | new_data = gaussian_filter_1d_new_torch(data_tensor, sigma=sigma, dim=dim).detach().cpu().numpy() 69 | 70 | np.testing.assert_allclose(old_data, new_data) --------------------------------------------------------------------------------
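A note on running the tests collected above: throughout this suite, expensive cases are gated with pytest markers such as slow, download, and theano (visible in the decorators), so they can be deselected with a marker expression. Below is a minimal sketch of such a run, assuming these markers are registered in the project's pytest.ini (not shown here); it is an illustration, not the project's official test entry point.

# Minimal sketch: run the suite while deselecting tests marked as slow,
# requiring downloads, or requiring theano. The marker names are taken from
# the decorators in the test files above; their registration in pytest.ini
# is assumed rather than shown.
import sys

import pytest

if __name__ == '__main__':
    # pytest.main accepts a list of command-line arguments; '-m' filters by
    # marker expression, and the return value is the pytest exit code.
    sys.exit(pytest.main(['-m', 'not slow and not download and not theano', 'tests']))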