├── ABXpy ├── database │ ├── __init__.py │ └── database.py ├── distances │ ├── __init__.py │ ├── metrics │ │ ├── __init__.py │ │ ├── cosine.py │ │ ├── utw.py │ │ ├── dtw │ │ │ └── dtw.pyx │ │ └── kullback_leibler.py │ └── example_distances.py ├── h5tools │ ├── __init__.py │ ├── np2h5.py │ └── h52np.py ├── dbfun │ ├── __init__.py │ ├── dbfun.py │ ├── dbfun_column.py │ ├── lookuptable_connector.py │ └── dbfun_compute.py ├── misc │ ├── __init__.py │ ├── type_fitting.py │ ├── progress_display.py │ ├── any2h5features.py │ └── items.py ├── sampling │ ├── __init__.py │ └── sampler.py ├── sideop │ ├── __init__.py │ ├── filter_manager.py │ └── regressor_manager.py ├── __init__.py ├── verify.py ├── distance.py ├── score.py └── analyze.py ├── .conda ├── conda_build_config.yaml └── meta.yaml ├── doc ├── readthedocs-pip-requirements.txt ├── _templates │ └── footer.html ├── ABXpy.sampling.rst ├── ABXpy.database.rst ├── ABXpy.distances.metrics.rst ├── ABXpy.distances.rst ├── index.rst ├── ABXpy.misc.rst ├── ABXpy.sideop.rst ├── ABXpy.h5tools.rst ├── ABXpy.rst ├── ABXpy.dbfun.rst ├── NumberOfCores.rst ├── FilesFormat.rst ├── make.bat ├── Makefile └── conf.py ├── test ├── frozen_files │ ├── data.abx │ ├── data.score │ ├── data.distance │ ├── data.features │ ├── data.item │ └── data.csv ├── test_examples.py ├── test_dtw.py ├── test_score.py ├── test_sampling.py ├── test_analyse.py └── test_task.py ├── examples ├── example_items │ ├── data.features │ └── data.item ├── complete_run.sh └── complete_run.py ├── setup.cfg ├── environment.yml ├── Makefile ├── .travis.yml ├── CHANGELOG.rst ├── .gitignore ├── CONTRIBUTING.rst ├── LICENSE.txt ├── setup.py ├── .gitlab-ci.yml ├── README.rst └── bin └── ABXrun.py /ABXpy/database/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ABXpy/distances/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ABXpy/h5tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.conda/conda_build_config.yaml: -------------------------------------------------------------------------------- 1 | python: 2 | - 3.6 3 | - 3.7 4 | - 3.8 5 | -------------------------------------------------------------------------------- /ABXpy/dbfun/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import dbfun_compute, dbfun_lookuptable 2 | -------------------------------------------------------------------------------- /doc/readthedocs-pip-requirements.txt: -------------------------------------------------------------------------------- 1 | Sphinx>=0.6 2 | numpydoc>=0.4 3 | numpy>=1.6 4 | mock>=1.0 5 | -------------------------------------------------------------------------------- /test/frozen_files/data.abx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bootphon/ABXpy/HEAD/test/frozen_files/data.abx -------------------------------------------------------------------------------- /test/frozen_files/data.score: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bootphon/ABXpy/HEAD/test/frozen_files/data.score -------------------------------------------------------------------------------- /test/frozen_files/data.distance: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bootphon/ABXpy/HEAD/test/frozen_files/data.distance -------------------------------------------------------------------------------- /test/frozen_files/data.features: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bootphon/ABXpy/HEAD/test/frozen_files/data.features -------------------------------------------------------------------------------- /doc/_templates/footer.html: -------------------------------------------------------------------------------- 1 | {%- extends "sphinx_rtd_theme/footer.html" %} 2 | 3 | {% set show_sphinx = False %} 4 | -------------------------------------------------------------------------------- /examples/example_items/data.features: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bootphon/ABXpy/HEAD/examples/example_items/data.features -------------------------------------------------------------------------------- /ABXpy/misc/__init__.py: -------------------------------------------------------------------------------- 1 | """This module contains several useful functions and classes that 2 | dont fit in any other modules. 3 | 4 | """ 5 | -------------------------------------------------------------------------------- /ABXpy/sampling/__init__.py: -------------------------------------------------------------------------------- 1 | """This module implement an incremental sampler used to approximate the task 2 | and randomly select a portion of the triplets. 3 | """ 4 | -------------------------------------------------------------------------------- /ABXpy/distances/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sun Aug 18 22:31:59 2013 4 | 5 | @author: thomas 6 | 7 | This file only serves to signal that the content of the folder is a Python package. 
8 | """ 9 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [aliases] 2 | test=pytest 3 | 4 | [tool:pytest] 5 | addopts = --verbose --doctest-modules --cov=ABXpy --cov-report=html --cov-report=term:skip-covered 6 | testpaths = ABXpy test 7 | python_files = test/*.py 8 | 9 | [build_sphinx] 10 | source-dir = doc 11 | build-dir = build/doc 12 | -------------------------------------------------------------------------------- /ABXpy/sideop/__init__.py: -------------------------------------------------------------------------------- 1 | """This module contains the filter and regressor managers used by 2 | task to apply the filters and regressors. Both those classes use a 3 | side operation manager that implement the generic functions. This 4 | allow to apply the filters and regressors as early as possible during 5 | the triplet generation to optimise the performances. 6 | 7 | """ 8 | -------------------------------------------------------------------------------- /doc/ABXpy.sampling.rst: -------------------------------------------------------------------------------- 1 | sampling Package 2 | ================ 3 | 4 | :mod:`sampling` Package 5 | ----------------------- 6 | 7 | .. automodule:: ABXpy.sampling 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | :mod:`sampler` Module 13 | --------------------- 14 | 15 | .. automodule:: ABXpy.sampling.sampler 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | 20 | -------------------------------------------------------------------------------- /doc/ABXpy.database.rst: -------------------------------------------------------------------------------- 1 | database Package 2 | ================ 3 | 4 | :mod:`database` Package 5 | ----------------------- 6 | 7 | .. automodule:: ABXpy.database 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | :mod:`database` Module 13 | ---------------------- 14 | 15 | .. automodule:: ABXpy.database.database 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | 20 | -------------------------------------------------------------------------------- /doc/ABXpy.distances.metrics.rst: -------------------------------------------------------------------------------- 1 | metrics Package 2 | =============== 3 | 4 | :mod:`metrics` Package 5 | ---------------------- 6 | 7 | .. automodule:: ABXpy.distances.metrics 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | :mod:`cosine` Module 13 | -------------------- 14 | 15 | .. 
automodule:: ABXpy.distances.metrics.cosine 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | 20 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: ABXpy 2 | channels: 3 | - coml 4 | - conda-forge 5 | dependencies: 6 | - python>=3.6 7 | - cython>=0.20.1 8 | - editdistance 9 | - h5features>=1.3 10 | - h5py==2.* 11 | - numpy>=1.8.1 12 | - numpydoc 13 | - mock 14 | - pandas>=0.13.1 15 | - pip 16 | - pytables 17 | - pytest>=2.6 18 | - pytest-cov 19 | - pytest-runner 20 | - scipy>=0.14.0 21 | - setuptools 22 | - sphinx_rtd_theme 23 | -------------------------------------------------------------------------------- /doc/ABXpy.distances.rst: -------------------------------------------------------------------------------- 1 | distances Package 2 | ================= 3 | 4 | :mod:`distances` Package 5 | ------------------------ 6 | 7 | .. automodule:: ABXpy.distances 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | :mod:`distances` Module 13 | ----------------------- 14 | 15 | .. automodule:: ABXpy.distances.distances 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | 20 | Subpackages 21 | ----------- 22 | 23 | .. toctree:: 24 | 25 | ABXpy.distances.metrics 26 | 27 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # simple makefile to simplify repetitive build env management tasks under posix 2 | 3 | PYTHON ?= python 4 | 5 | .PHONY: build install develop doc test clean 6 | 7 | build: 8 | $(PYTHON) setup.py build 9 | 10 | install: build 11 | $(PYTHON) setup.py install 12 | 13 | develop: build 14 | $(PYTHON) setup.py develop 15 | 16 | doc: build 17 | $(PYTHON) setup.py build_sphinx 18 | 19 | test: build 20 | $(PYTHON) setup.py test 21 | 22 | clean: 23 | $(PYTHON) setup.py clean 24 | find . -name '*.pyc' -delete 25 | find . -name '*.so' -delete 26 | find . -name __pycache__ -exec rm -rf {} + 27 | rm -rf .eggs *.egg-info 28 | rm -rf ABXpy/distances/metrics/dtw/*.c 29 | rm -rf build dist htmlcov .coverage* 30 | -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | .. ABXpy documentation master file, created by 2 | sphinx-quickstart on Wed May 7 22:50:13 2014. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to ABXpy's documentation! 7 | ================================= 8 | 9 | .. note:: 10 | 11 | For installation instructions please see the `README 12 | `_. 13 | 14 | 15 | Contents 16 | -------- 17 | 18 | .. 
toctree:: 19 | :maxdepth: 3 20 | 21 | ABXpy 22 | FilesFormat 23 | NumberOfCores 24 | 25 | 26 | Indices and tables 27 | ------------------ 28 | 29 | * :ref:`genindex` 30 | * :ref:`modindex` 31 | * :ref:`search` 32 | -------------------------------------------------------------------------------- /ABXpy/misc/type_fitting.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def fit_integer_type(n, is_signed=True): 5 | """Returns the minimal space needed to store integers of maximal value n""" 6 | 7 | if is_signed: 8 | m = 1 9 | types = [np.int8, np.int16, np.int32, np.int64] 10 | else: 11 | m = 0 12 | types = [np.uint8, np.uint16, np.uint32, np.uint64] 13 | 14 | if n < 2 ** (8 - m): 15 | return types[0] 16 | elif n < 2 ** (16 - m): 17 | return types[1] 18 | elif n < 2 ** (32 - m): 19 | return types[2] 20 | elif n < 2 ** (64 - m): 21 | return types[3] 22 | else: 23 | raise ValueError( 24 | 'Values are too big to be represented by 64 bits integers!') 25 | -------------------------------------------------------------------------------- /doc/ABXpy.misc.rst: -------------------------------------------------------------------------------- 1 | misc Package 2 | ============ 3 | 4 | :mod:`misc` Package 5 | ------------------- 6 | 7 | .. automodule:: ABXpy.misc 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | :mod:`progress_display` Module 13 | ------------------------------ 14 | 15 | .. automodule:: ABXpy.misc.progress_display 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | 20 | :mod:`tinytree` Module 21 | ---------------------- 22 | 23 | .. automodule:: ABXpy.misc.tinytree 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | :mod:`type_fitting` Module 29 | -------------------------- 30 | 31 | .. automodule:: ABXpy.misc.type_fitting 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | 36 | -------------------------------------------------------------------------------- /doc/ABXpy.sideop.rst: -------------------------------------------------------------------------------- 1 | sideop Package 2 | ============== 3 | 4 | :mod:`sideop` Package 5 | --------------------- 6 | 7 | .. automodule:: ABXpy.sideop 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | :mod:`filter_manager` Module 13 | ---------------------------- 14 | 15 | .. automodule:: ABXpy.sideop.filter_manager 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | 20 | :mod:`regressor_manager` Module 21 | ------------------------------- 22 | 23 | .. automodule:: ABXpy.sideop.regressor_manager 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | :mod:`side_operations_manager` Module 29 | ------------------------------------- 30 | 31 | .. 
automodule:: ABXpy.sideop.side_operations_manager 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | 36 | -------------------------------------------------------------------------------- /examples/complete_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # This test contains a full run of the ABX pipeline in command line 4 | # with randomly created database and features 5 | 6 | set -e 7 | 8 | cd "$( dirname "${BASH_SOURCE[0]}" )" 9 | 10 | # input files already here 11 | item=example_items/data.item 12 | features=example_items/data.features 13 | 14 | # output files produced by ABX 15 | task=example_items/data.abx 16 | distance=example_items/data.distance 17 | score=example_items/data.score 18 | analyze=example_items/data.csv 19 | 20 | # generating task file 21 | abx-task $item $task --verbose --on c0 --across c1 --by c2 22 | 23 | # computing distances 24 | abx-distance $features $task $distance --normalization 1 --njobs 1 25 | 26 | # calculating the score 27 | abx-score $task $distance $score 28 | 29 | # collapsing the results 30 | abx-analyze $score $task $analyze 31 | -------------------------------------------------------------------------------- /test/test_examples.py: -------------------------------------------------------------------------------- 1 | """This file contains test for the examples of the package""" 2 | 3 | import os 4 | import pytest 5 | import shutil 6 | import subprocess 7 | 8 | 9 | # when testing the conda recipe, the example scripts are not here, 10 | # and the test failed if not skipped 11 | @pytest.mark.skipif( 12 | 'CONDA_BUILD_STATE' in os.environ, 13 | reason="no example scripts during conda build") 14 | @pytest.mark.parametrize('ext', ['.py', '.sh']) 15 | def test_complete_run(ext): 16 | folder = os.path.join( 17 | os.path.dirname(os.path.dirname(__file__)), 'examples') 18 | script = os.path.join(folder, 'complete_run' + ext) 19 | 20 | assert os.path.isfile(script) 21 | 22 | try: 23 | shutil.rmtree('./example_items', ignore_errors=True) 24 | subprocess.check_call([script]) 25 | finally: 26 | shutil.rmtree('./example_items', ignore_errors=True) 27 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | notifications: 2 | email: false 3 | 4 | # no need of git history 5 | git: 6 | depth: 1 7 | 8 | # test only the master branch 9 | branches: 10 | only: 11 | - master 12 | 13 | install: 14 | # setup anaconda 15 | - wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh 16 | - bash miniconda.sh -b -p $HOME/miniconda 17 | - source "$HOME/miniconda/etc/profile.d/conda.sh" 18 | - hash -r 19 | - conda config --set always_yes yes --set changeps1 no 20 | - conda update -q conda 21 | # useful for debugging any issues with conda 22 | - conda info -a 23 | # install ABXpy environment 24 | - conda env create -n test-environment -f environment.yml 25 | - conda activate test-environment 26 | - pip install codecov 27 | # install ABXpy 28 | - python setup.py install 29 | 30 | script: 31 | - python setup.py test 32 | 33 | after_success: 34 | - codecov 35 | -------------------------------------------------------------------------------- /doc/ABXpy.h5tools.rst: -------------------------------------------------------------------------------- 1 | h5tools Package 2 | =============== 3 | 4 | :mod:`h5tools` Package 5 | ---------------------- 6 | 7 | .. 
automodule:: ABXpy.h5tools 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | :mod:`h52np` Module 13 | ------------------- 14 | 15 | .. automodule:: ABXpy.h5tools.h52np 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | 20 | :mod:`h5_handler` Module 21 | ------------------------ 22 | 23 | .. automodule:: ABXpy.h5tools.h5_handler 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | :mod:`h5io` Module 29 | ------------------ 30 | 31 | .. automodule:: ABXpy.h5tools.h5io 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | 36 | :mod:`np2h5` Module 37 | ------------------- 38 | 39 | .. automodule:: ABXpy.h5tools.np2h5 40 | :members: 41 | :undoc-members: 42 | :show-inheritance: 43 | 44 | -------------------------------------------------------------------------------- /doc/ABXpy.rst: -------------------------------------------------------------------------------- 1 | ABXpy Package 2 | ============= 3 | 4 | :mod:`ABXpy` Package 5 | -------------------- 6 | 7 | .. automodule:: ABXpy.__init__ 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | :mod:`task` Module 13 | ------------------ 14 | 15 | .. automodule:: ABXpy.task 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | 20 | :mod:`score` Module 21 | ------------------- 22 | 23 | .. automodule:: ABXpy.score 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | :mod:`analyze` Module 29 | --------------------- 30 | 31 | .. automodule:: ABXpy.analyze 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | 36 | Subpackages 37 | ----------- 38 | 39 | .. toctree:: 40 | 41 | ABXpy.distances 42 | ABXpy.sampling 43 | ABXpy.database 44 | ABXpy.h5tools 45 | ABXpy.dbfun 46 | ABXpy.sideop 47 | ABXpy.misc 48 | 49 | -------------------------------------------------------------------------------- /ABXpy/dbfun/dbfun.py: -------------------------------------------------------------------------------- 1 | """Abstract API for getting functions of data in a database 2 | 3 | API define one attribute (input_names) and one method (evaluate) 4 | 5 | """ 6 | 7 | 8 | class DBfun(object): 9 | def __init__(self, input_names): 10 | self.input_names = input_names 11 | 12 | # input must contain a dictionary whos keys are the input_names 13 | def evaluate(self, inputs_dict): 14 | pass 15 | # do some generic checks here ? 16 | # e.g. set(input_names) == set(input_dict.keys()) 17 | # or each element in inputs_dict is an array with the same number of 18 | # lines (and possibly different types and number of columns) 19 | 20 | # should return at least the number of outputs and if possible an ordered 21 | # list of output names + a dictionary {output_name: index} containing all 22 | # indexed outputs and their indexes 23 | def output_specs(self): 24 | pass 25 | -------------------------------------------------------------------------------- /CHANGELOG.rst: -------------------------------------------------------------------------------- 1 | ========= 2 | ChangeLog 3 | ========= 4 | 5 | ABXpy-0.4.3 6 | =========== 7 | 8 | * no more compatibility with **python2** (and removed dependency to future/past). 9 | 10 | * new feature: add *editdistance* as available distance. 11 | 12 | * bugfix with python-3.8 and *ast.Module*. 13 | 14 | * fixed a lot of deprecation warnings. 
15 | 16 | * documentation moved to https://coml.lscp.ens.fr/docs/abx 17 | 18 | 19 | ABXpy-0.4.2 20 | =========== 21 | 22 | * compatibility with **python3** 23 | 24 | * releases are now deployed on `conda 25 | `_ for linux and macos. 26 | 27 | * documentation is hosted on https://coml.lscp.ens.fr/abx 28 | 29 | * bugfixes: 30 | 31 | * bugfix in the Kullback Leibler divergence metric with numpy array shape 32 | 33 | * bugfix: cosine.py returned shape (1,1,1,1) instead of (1,1) when 34 | metric(x,y) is shape (1,1) 35 | 36 | 37 | ABXpy-0.4.1 38 | =========== 39 | 40 | No ChangeLog beyond this release. 41 | -------------------------------------------------------------------------------- /ABXpy/distances/example_distances.py: -------------------------------------------------------------------------------- 1 | import ABXpy.distances.metrics.dtw as dtw 2 | import ABXpy.distances.metrics.cosine as cosine 3 | import ABXpy.distances.metrics.kullback_leibler as kullback_leibler 4 | import numpy as np 5 | 6 | 7 | def dtw_cosine(x, y): 8 | """Dynamic time warping cosine distance 9 | 10 | The "feature" dimension is along the columns and the "time" 11 | dimension along the lines of arrays x and y 12 | 13 | """ 14 | if x.shape[0] > 0 and y.shape[0] > 0: 15 | # x and y are not empty 16 | d = dtw.dtw(x, y, cosine.cosine_distance) 17 | elif x.shape[0] == y.shape[0]: 18 | # both x and y are empty 19 | d = 0 20 | else: 21 | # x or y is empty 22 | d = np.inf 23 | return d 24 | 25 | 26 | def dtw_kl_divergence(x, y): 27 | """Kullback-Leibler divergence""" 28 | if x.shape[0] > 0 and y.shape[0] > 0: 29 | d = dtw.dtw(x, y, kullback_leibler.kl_divergence) 30 | elif x.shape[0] == y.shape[0]: 31 | d = 0 32 | else: 33 | d = np.inf 34 | return d 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Mac specific files 2 | .DS_Store 3 | 4 | # Spyder IDE files 5 | .spyderproject 6 | .spyderworkspace 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | env/ 18 | bin/ 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | .eggs/ 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage* 41 | coverage.xml 42 | .cache 43 | nosetests.xml 44 | 45 | # Translations 46 | *.mo 47 | 48 | # Mr Developer 49 | .mr.developer.cfg 50 | .project 51 | .pydevproject 52 | 53 | # Rope 54 | .ropeproject 55 | 56 | # Django stuff: 57 | *.log 58 | *.pot 59 | 60 | # pyTest cache 61 | test/__pycache__ 62 | 63 | # Emacs 64 | *~ 65 | \#*\# 66 | 67 | # Cythonized output 68 | ABXpy/distances/metrics/dtw/*.c 69 | ABXpy/distances/metrics/install/*.c 70 | 71 | # Example items 72 | examples/example_items/data.* 73 | -------------------------------------------------------------------------------- /doc/ABXpy.dbfun.rst: -------------------------------------------------------------------------------- 1 | dbfun Package 2 | ============= 3 | 4 | :mod:`dbfun` Package 5 | -------------------- 6 | 7 | .. automodule:: ABXpy.dbfun 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | 12 | :mod:`dbfun` Module 13 | ------------------- 14 | 15 | .. 
automodule:: ABXpy.dbfun.dbfun 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | 20 | :mod:`dbfun_column` Module 21 | -------------------------- 22 | 23 | .. automodule:: ABXpy.dbfun.dbfun_column 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | :mod:`dbfun_compute` Module 29 | --------------------------- 30 | 31 | .. automodule:: ABXpy.dbfun.dbfun_compute 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | 36 | :mod:`dbfun_lookuptable` Module 37 | ------------------------------- 38 | 39 | .. automodule:: ABXpy.dbfun.dbfun_lookuptable 40 | :members: 41 | :undoc-members: 42 | :show-inheritance: 43 | 44 | :mod:`lookuptable_connector` Module 45 | ----------------------------------- 46 | 47 | .. automodule:: ABXpy.dbfun.lookuptable_connector 48 | :members: 49 | :undoc-members: 50 | :show-inheritance: 51 | 52 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | Contributing 3 | ============ 4 | 5 | If you discover issues, have ideas for improvements or new features, 6 | or want to contribute a new module, please report them to the `issue 7 | tracker`_ of the repository or submit a pull request. Please, try to 8 | follow these guidelines when you do so. 9 | 10 | Issue reporting 11 | =============== 12 | 13 | * Check that the issue has not already been reported or fixed in the 14 | latest code (a.k.a. `master`_). 15 | * Include any relevant code to the issue summary, when possible write 16 | a broken test case. 17 | 18 | Pull requests 19 | ============= 20 | 21 | * Read how to properly contribute to open source projects on Github `here`_. 22 | * Use a topic branch to easily amend a pull request later, if necessary. 23 | * Use the same coding conventions as the rest of the project. 24 | * Open a `pull request`_ that relates to *only* one subject. 25 | 26 | .. _issue tracker: https://github.com/bootphon/ABXpy/issues 27 | .. _master: https://github.com/bootphon/ABXpy/tree/master 28 | .. _here: http://gun.io/blog/how-to-github-fork-branch-and-pull-request 29 | .. _pull request: https://help.github.com/articles/using-pull-requests 30 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2014-2019 Thomas Schatz, Emmanuel Dupoux and Roland Thiollière. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /test/test_dtw.py: -------------------------------------------------------------------------------- 1 | """Module for testing the dtw module""" 2 | 3 | import ABXpy.distances.metrics.dtw as dtw 4 | import numpy as np 5 | 6 | 7 | def test_small(): 8 | res = dtw._dtw(1, 1, np.ones((1, 1)), normalized=False) 9 | assert res == 1 10 | 11 | res = dtw._dtw(1, 1, np.ones((1, 1)), normalized=True) 12 | assert res == 1 13 | 14 | 15 | def test_normalized(): 16 | dists = np.ones((5, 5)) * 2 17 | dists = dists - np.diag(np.ones((5,))) 18 | res = dtw._dtw(5, 5, dists, normalized=True) 19 | assert res == 1 20 | 21 | dists_start = np.ones((5, 2)) * 2 22 | dists_start[0, :] = 1 23 | dists_start = np.concatenate([dists_start, dists], axis=1) 24 | res = dtw._dtw(5, 7, dists_start, normalized=True) 25 | assert res == 1 26 | 27 | dists_start = np.ones((2, 5)) * 2 28 | dists_start[:, 0] = 1 29 | dists_start = np.concatenate([dists_start, dists], axis=0) 30 | res = dtw._dtw(7, 5, dists_start, normalized=True) 31 | assert res == 1 32 | 33 | dists_mid = np.ones((5, 2)) * 2 34 | dists_mid[3, :] = 1 35 | dists_mid = np.concatenate([dists[:, :3], dists_mid, dists[:, 3:]], axis=1) 36 | res = dtw._dtw(5, 7, dists_mid, normalized=True) 37 | assert res == 1 38 | -------------------------------------------------------------------------------- /ABXpy/dbfun/dbfun_column.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Dec 20 13:36:52 2013 4 | 5 | @author: Thomas Schatz 6 | """ 7 | 8 | import numpy as np 9 | 10 | from . 
import dbfun 11 | 12 | 13 | class DBfun_Column(dbfun.DBfun): 14 | def __init__(self, name, db=None, column=None, indexed=True): 15 | self.input_names = [name] 16 | self.n_outputs = 1 17 | if indexed: 18 | index = list(set(db[column])) 19 | index.sort() 20 | self.index = index 21 | else: 22 | self.index = [] 23 | 24 | def output_specs(self): 25 | if self.index: 26 | indexes = {self.input_names[0]: self.index} 27 | else: 28 | indexes = {} 29 | return self.n_outputs, self.input_names, indexes 30 | 31 | # function for evaluating the column function given data for the context 32 | # context is a dictionary with just the right name/content associations 33 | def evaluate(self, context): 34 | if self.index: 35 | # FIXME optimize this 36 | return [np.array([self.index.index(e) 37 | for e in context[self.input_names[0]]])] 38 | else: 39 | return [context[self.input_names[0]]] 40 | -------------------------------------------------------------------------------- /test/test_score.py: -------------------------------------------------------------------------------- 1 | """This test script contains tests for the basic parameters of score.py""" 2 | 3 | import os 4 | import shutil 5 | 6 | import ABXpy.task 7 | import ABXpy.distances.distances as distances 8 | import ABXpy.distances.metrics.cosine as cosine 9 | import ABXpy.distances.metrics.dtw as dtw 10 | import ABXpy.score as score 11 | import ABXpy.misc.items as items 12 | 13 | 14 | def dtw_cosine_distance(x, y, normalized): 15 | return dtw.dtw(x, y, cosine.cosine_distance, normalized) 16 | 17 | 18 | def test_score(): 19 | try: 20 | if not os.path.exists('test_items'): 21 | os.makedirs('test_items') 22 | item_file = 'test_items/data.item' 23 | feature_file = 'test_items/data.features' 24 | distance_file = 'test_items/data.distance' 25 | scorefilename = 'test_items/data.score' 26 | taskfilename = 'test_items/data.abx' 27 | items.generate_db_and_feat(3, 3, 1, item_file, 2, 3, feature_file) 28 | task = ABXpy.task.Task(item_file, 'c0', 'c1', 'c2') 29 | task.generate_triplets() 30 | distances.compute_distances( 31 | feature_file, '/features/', taskfilename, 32 | distance_file, dtw_cosine_distance, 33 | normalized=True, n_cpu=3) 34 | score.score(taskfilename, distance_file, scorefilename) 35 | finally: 36 | shutil.rmtree('test_items', ignore_errors=True) 37 | -------------------------------------------------------------------------------- /ABXpy/distances/metrics/cosine.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Jan 22 01:47:42 2014 4 | 5 | @author: Thomas Schatz 6 | """ 7 | import numpy as np 8 | 9 | # FIXME change name to just distance ou distance_matrix? compute 10 | # cosine distances between all possible pairs of lines in the x and y 11 | # matrix x and y should be 2D numpy arrays with "features" on the 12 | # lines and "times" on the columns x, y must be float arrays 13 | 14 | 15 | def cosine_distance(x, y): 16 | assert (x.dtype == np.float64 and y.dtype == np.float64) or ( 17 | x.dtype == np.float32 and y.dtype == np.float32) 18 | x2 = np.sqrt(np.sum(x ** 2, axis=1)) 19 | y2 = np.sqrt(np.sum(y ** 2, axis=1)) 20 | ix = x2 == 0. 21 | iy = y2 == 0. 
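# The pairwise cosine similarities are obtained in one shot as the dot
# product of the two frame matrices, normalised by the outer product of the
# frame norms computed above; arccos(.) / pi then maps each similarity from
# [-1, 1] to a distance in [0, 1]. Frames with a zero norm (flagged by ix
# and iy) would give a division by zero here, but their rows and columns of
# d are overwritten below: distance 1 against any regular frame, 0 between
# two all-zero frames.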
22 | d = np.dot(x, y.T) / (np.outer(x2, y2)) 23 | # DPX: to prevent the stupid scipy to collapse the array into scalar 24 | if d.shape == (1, 1): 25 | d = np.array([[np.float64(np.lib.scimath.arccos(d[0, 0]) / np.pi)]]).reshape((1, 1)) 26 | else: 27 | # costly in time (half of the time), so check if really useful for dtw 28 | d = np.float64(np.lib.scimath.arccos(d) / np.pi) 29 | 30 | d[ix, :] = 1. 31 | d[:, iy] = 1. 32 | for i in np.where(ix)[0]: 33 | d[i, iy] = 0. 34 | assert np.all(d >= 0) 35 | return d 36 | 37 | 38 | def normalize_cosine_distance(x, y): 39 | x /= x.sum(1).reshape(x.shape[0], 1) 40 | y /= y.sum(1).reshape(y.shape[0], 1) 41 | return cosine_distance(x, y) 42 | -------------------------------------------------------------------------------- /.conda/meta.yaml: -------------------------------------------------------------------------------- 1 | # Build the shennong conda package. Run with "conda build . -c coml" 2 | 3 | {% set name = 'abx' %} 4 | {% set data = load_setup_py_data() %} 5 | 6 | package: 7 | name: {{ name }} 8 | version: {{ data.get('version') }} 9 | 10 | source: 11 | path: .. 12 | 13 | build: 14 | entry_points: 15 | {% for entry in data.get('entry_points')['console_scripts'] %} 16 | - {{ entry }} 17 | {% endfor %} 18 | script: 19 | - conda install -c conda-forge editdistance -y 20 | - conda install -c coml h5features -y 21 | - python setup.py install 22 | 23 | requirements: 24 | build: 25 | - python {{ python }} 26 | - cython 27 | - h5py>=2.2.1 28 | - mock 29 | - numpy>=1.8.1 30 | - pandas>=0.13.1 31 | - pip 32 | - pytables 33 | - pytest 34 | - pytest-cov 35 | - pytest-runner 36 | - scipy>=0.14.0 37 | - setuptools 38 | 39 | run: 40 | - python {{ python }} 41 | - cython 42 | - h5py>=2.2.1 43 | - mock 44 | - numpy>=1.8.1 45 | - pandas>=0.13.1 46 | - pip 47 | - pytables 48 | - pytest 49 | - pytest-cov 50 | - pytest-runner 51 | - scipy>=0.14.0 52 | - setuptools 53 | 54 | test: 55 | imports: 56 | - ABXpy 57 | requires: 58 | - pytest>=2.6 59 | - pytest-cov 60 | source_files: 61 | - test 62 | commands: 63 | - abx-task -h 64 | - abx-score -h 65 | - abx-distance -h 66 | - pytest -vx 67 | 68 | about: 69 | home: {{ data.get('url') }} 70 | license: {{ data.get('license') }} 71 | summary: {{ data.get('description') }} 72 | -------------------------------------------------------------------------------- /examples/complete_run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Full run of ABX pipeline with randomly created database and features""" 3 | 4 | import os 5 | 6 | import ABXpy.task 7 | import ABXpy.distances.distances as distances 8 | import ABXpy.distances.metrics.cosine as cosine 9 | import ABXpy.distances.metrics.dtw as dtw 10 | import ABXpy.score as score 11 | import ABXpy.misc.items as items 12 | import ABXpy.analyze as analyze 13 | 14 | 15 | def dtw_cosine_distance(x, y, normalized): 16 | return dtw.dtw(x, y, cosine.cosine_distance, normalized) 17 | 18 | 19 | def fullrun(): 20 | if not os.path.exists('example_items'): 21 | os.makedirs('example_items') 22 | item_file = 'example_items/data.item' 23 | feature_file = 'example_items/data.features' 24 | distance_file = 'example_items/data.distance' 25 | scorefilename = 'example_items/data.score' 26 | taskfilename = 'example_items/data.abx' 27 | analyzefilename = 'example_items/data.csv' 28 | 29 | # deleting pre-existing files 30 | for f in [item_file, feature_file, distance_file, 31 | scorefilename, taskfilename, analyzefilename]: 32 | try: 33 | 
os.remove(f) 34 | except OSError: 35 | pass 36 | 37 | # running the evaluation 38 | items.generate_db_and_feat(3, 3, 5, item_file, 2, 2, feature_file) 39 | 40 | task = ABXpy.task.Task(item_file, 'c0', across='c1', by='c2') 41 | task.generate_triplets(taskfilename) 42 | 43 | distances.compute_distances( 44 | feature_file, 'features', taskfilename, 45 | distance_file, dtw_cosine_distance, 46 | normalized=True, n_cpu=1) 47 | 48 | score.score(taskfilename, distance_file, scorefilename) 49 | 50 | analyze.analyze(taskfilename, scorefilename, analyzefilename) 51 | 52 | 53 | if __name__ == '__main__': 54 | fullrun() 55 | -------------------------------------------------------------------------------- /doc/NumberOfCores.rst: -------------------------------------------------------------------------------- 1 | =========================================== 2 | How to choose the right number of CPU cores 3 | =========================================== 4 | 5 | The amount of cores you can ask for ``abx-distance`` is only limited by the 6 | number of cores available on the machine you're using (unless you have very big 7 | "BY" blocks in your features file, in which case you need to make sure 8 | Size-of-the-biggest-BY-block*number-of-cores does not exceed the available 9 | memory). 10 | 11 | The running time is going to be essentially 12 | number-of-distances-to-be-computed*time-required-to-compute-a-distance. You can 13 | get an estimate of that easily by taking a smaller item file, obtaining the 14 | number of distances to be computed with the original item file n and with the 15 | smaller file **n_small**. You then run the ABX evaluation for the smaller file 16 | and look at the time **t_small** it takes using only 1 core. Then you can 17 | estimate the time it will take for the bigger file using **n_cores** cores: **t 18 | = t_small*n/(n_small*n_cores)**. 19 | 20 | If the estimated time is too long given your deadlines, you can reduce the 21 | number of pairs to be computed without compromising your results by sampling the 22 | item file. You have to do it in a way that make sense given the ABX task you are 23 | using and the analyses you want to do. For example for a task "ON phoneme BY 24 | speaker, context": 25 | 26 | * For any (phoneme, speaker, context) triple that appears on more than **k** 27 | lines in the item file, you can keep **k** lines only (randomly sampled). 28 | **k=10** or even **k=5** should be largely sufficient, unless you are not 29 | averaging over speakers and contexts in your analyses. 30 | * You can also remove any line with a (phoneme, speaker, context) triple that 31 | appears only once in the item file, as those cannot be used to estimate 32 | symetrised scores (unless you are interested in the asymetries). 
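As a back-of-the-envelope illustration of the running-time formula given
above (all the numbers below are made up), estimating the full run from a
timing obtained on a reduced item file could look like::

    n = 2000000       # distances required by the full item file
    n_small = 50000   # distances required by the reduced item file
    t_small = 600.0   # seconds taken by the reduced run on a single core
    n_cores = 20      # cores available for the full run

    t = t_small * n / (n_small * n_cores)
    print(t / 60)     # -> 20.0 minutes estimated for the full run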
33 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | import setuptools.command.build_ext 3 | import ABXpy 4 | 5 | 6 | class build_ext(setuptools.command.build_ext.build_ext): 7 | def finalize_options(self): 8 | setuptools.command.build_ext.build_ext.finalize_options(self) 9 | # Prevent numpy from thinking it is still in its setup process: 10 | __builtins__.__NUMPY_SETUP__ = False 11 | import numpy 12 | self.include_dirs.append(numpy.get_include()) 13 | 14 | 15 | setuptools.setup( 16 | name='ABXpy', 17 | version=ABXpy.version, 18 | author='Thomas Schatz', 19 | description='ABX discrimination task', 20 | long_description=open('README.rst').read(), 21 | url='https://github.com/bootphon/ABXpy', 22 | license='LICENSE.txt', 23 | 24 | packages=setuptools.find_packages(exclude='test'), 25 | 26 | # needed for cython/setuptools, see 27 | # http://docs.cython.org/en/latest/src/quickstart/build.html 28 | zip_safe=False, 29 | 30 | setup_requires=[ 31 | 'editdistance', 32 | 'cython', 33 | 'setuptools>=18.0', 34 | 'numpy>=1.9.0', 35 | 'pytest-runner' 36 | ], 37 | 38 | install_requires=[ 39 | 'h5py >= 2.2.1', 40 | 'numpy >= 1.8.0', 41 | 'pandas >= 0.13.1', 42 | 'scipy >= 0.13.0', 43 | 'tables', 44 | ], 45 | 46 | tests_require=[ 47 | 'h5features', 48 | 'pytest>=2.6', 49 | 'pytest-cov' 50 | ], 51 | 52 | ext_modules=[setuptools.Extension( 53 | 'ABXpy.distances.metrics.dtw', 54 | sources=['ABXpy/distances/metrics/dtw/dtw.pyx'], 55 | extra_compile_args=['-O3'])], 56 | 57 | cmdclass={'build_ext': build_ext}, 58 | 59 | entry_points={'console_scripts': [ 60 | 'abx-task = ABXpy.task:main', 61 | 'abx-distance = ABXpy.distance:main', 62 | 'abx-analyze = ABXpy.analyze:main', 63 | 'abx-score = ABXpy.score:main', 64 | ]} 65 | ) 66 | -------------------------------------------------------------------------------- /test/frozen_files/data.item: -------------------------------------------------------------------------------- 1 | #file onset offset #item c0 c1 c2 2 | s0 0.0 0.0 i0 c0_v0 c1_v0 c2_v0 3 | s1 0.0 0.0 i1 c0_v1 c1_v0 c2_v0 4 | s2 0.0 0.0 i2 c0_v2 c1_v0 c2_v0 5 | s3 0.0 0.0 i3 c0_v0 c1_v1 c2_v0 6 | s4 0.0 0.0 i4 c0_v1 c1_v1 c2_v0 7 | s5 0.0 0.0 i5 c0_v2 c1_v1 c2_v0 8 | s6 0.0 0.0 i6 c0_v0 c1_v2 c2_v0 9 | s7 0.0 0.0 i7 c0_v1 c1_v2 c2_v0 10 | s8 0.0 0.0 i8 c0_v2 c1_v2 c2_v0 11 | s9 0.0 0.0 i9 c0_v0 c1_v0 c2_v1 12 | s10 0.0 0.0 i10 c0_v1 c1_v0 c2_v1 13 | s11 0.0 0.0 i11 c0_v2 c1_v0 c2_v1 14 | s12 0.0 0.0 i12 c0_v0 c1_v1 c2_v1 15 | s13 0.0 0.0 i13 c0_v1 c1_v1 c2_v1 16 | s14 0.0 0.0 i14 c0_v2 c1_v1 c2_v1 17 | s15 0.0 0.0 i15 c0_v0 c1_v2 c2_v1 18 | s16 0.0 0.0 i16 c0_v1 c1_v2 c2_v1 19 | s17 0.0 0.0 i17 c0_v2 c1_v2 c2_v1 20 | s18 0.0 0.0 i18 c0_v0 c1_v0 c2_v2 21 | s19 0.0 0.0 i19 c0_v1 c1_v0 c2_v2 22 | s20 0.0 0.0 i20 c0_v2 c1_v0 c2_v2 23 | s21 0.0 0.0 i21 c0_v0 c1_v1 c2_v2 24 | s22 0.0 0.0 i22 c0_v1 c1_v1 c2_v2 25 | s23 0.0 0.0 i23 c0_v2 c1_v1 c2_v2 26 | s24 0.0 0.0 i24 c0_v0 c1_v2 c2_v2 27 | s25 0.0 0.0 i25 c0_v1 c1_v2 c2_v2 28 | s26 0.0 0.0 i26 c0_v2 c1_v2 c2_v2 29 | s27 0.0 0.0 i27 c0_v0 c1_v0 c2_v0 30 | s28 0.0 0.0 i28 c0_v1 c1_v0 c2_v0 31 | s29 0.0 0.0 i29 c0_v2 c1_v0 c2_v0 32 | s30 0.0 0.0 i30 c0_v0 c1_v1 c2_v0 33 | s31 0.0 0.0 i31 c0_v1 c1_v1 c2_v0 34 | s32 0.0 0.0 i32 c0_v2 c1_v1 c2_v0 35 | s33 0.0 0.0 i33 c0_v0 c1_v2 c2_v0 36 | s34 0.0 0.0 i34 c0_v1 c1_v2 c2_v0 37 | s35 0.0 0.0 i35 c0_v2 c1_v2 c2_v0 38 | s36 0.0 0.0 i36 c0_v0 c1_v0 c2_v1 39 | s37 0.0 0.0 i37 
c0_v1 c1_v0 c2_v1 40 | s38 0.0 0.0 i38 c0_v2 c1_v0 c2_v1 41 | s39 0.0 0.0 i39 c0_v0 c1_v1 c2_v1 42 | s40 0.0 0.0 i40 c0_v1 c1_v1 c2_v1 43 | s41 0.0 0.0 i41 c0_v2 c1_v1 c2_v1 44 | s42 0.0 0.0 i42 c0_v0 c1_v2 c2_v1 45 | s43 0.0 0.0 i43 c0_v1 c1_v2 c2_v1 46 | s44 0.0 0.0 i44 c0_v2 c1_v2 c2_v1 47 | s45 0.0 0.0 i45 c0_v0 c1_v0 c2_v2 48 | s46 0.0 0.0 i46 c0_v1 c1_v0 c2_v2 49 | s47 0.0 0.0 i47 c0_v2 c1_v0 c2_v2 50 | s48 0.0 0.0 i48 c0_v0 c1_v1 c2_v2 51 | s49 0.0 0.0 i49 c0_v1 c1_v1 c2_v2 52 | s50 0.0 0.0 i50 c0_v2 c1_v1 c2_v2 53 | s51 0.0 0.0 i51 c0_v0 c1_v2 c2_v2 54 | s52 0.0 0.0 i52 c0_v1 c1_v2 c2_v2 55 | s53 0.0 0.0 i53 c0_v2 c1_v2 c2_v2 56 | -------------------------------------------------------------------------------- /examples/example_items/data.item: -------------------------------------------------------------------------------- 1 | #file onset offset #item c0 c1 c2 2 | s0 0.0 0.0 i0 c0_v0 c1_v0 c2_v0 3 | s1 0.0 0.0 i1 c0_v1 c1_v0 c2_v0 4 | s2 0.0 0.0 i2 c0_v2 c1_v0 c2_v0 5 | s3 0.0 0.0 i3 c0_v0 c1_v1 c2_v0 6 | s4 0.0 0.0 i4 c0_v1 c1_v1 c2_v0 7 | s5 0.0 0.0 i5 c0_v2 c1_v1 c2_v0 8 | s6 0.0 0.0 i6 c0_v0 c1_v2 c2_v0 9 | s7 0.0 0.0 i7 c0_v1 c1_v2 c2_v0 10 | s8 0.0 0.0 i8 c0_v2 c1_v2 c2_v0 11 | s9 0.0 0.0 i9 c0_v0 c1_v0 c2_v1 12 | s10 0.0 0.0 i10 c0_v1 c1_v0 c2_v1 13 | s11 0.0 0.0 i11 c0_v2 c1_v0 c2_v1 14 | s12 0.0 0.0 i12 c0_v0 c1_v1 c2_v1 15 | s13 0.0 0.0 i13 c0_v1 c1_v1 c2_v1 16 | s14 0.0 0.0 i14 c0_v2 c1_v1 c2_v1 17 | s15 0.0 0.0 i15 c0_v0 c1_v2 c2_v1 18 | s16 0.0 0.0 i16 c0_v1 c1_v2 c2_v1 19 | s17 0.0 0.0 i17 c0_v2 c1_v2 c2_v1 20 | s18 0.0 0.0 i18 c0_v0 c1_v0 c2_v2 21 | s19 0.0 0.0 i19 c0_v1 c1_v0 c2_v2 22 | s20 0.0 0.0 i20 c0_v2 c1_v0 c2_v2 23 | s21 0.0 0.0 i21 c0_v0 c1_v1 c2_v2 24 | s22 0.0 0.0 i22 c0_v1 c1_v1 c2_v2 25 | s23 0.0 0.0 i23 c0_v2 c1_v1 c2_v2 26 | s24 0.0 0.0 i24 c0_v0 c1_v2 c2_v2 27 | s25 0.0 0.0 i25 c0_v1 c1_v2 c2_v2 28 | s26 0.0 0.0 i26 c0_v2 c1_v2 c2_v2 29 | s27 0.0 0.0 i27 c0_v0 c1_v0 c2_v0 30 | s28 0.0 0.0 i28 c0_v1 c1_v0 c2_v0 31 | s29 0.0 0.0 i29 c0_v2 c1_v0 c2_v0 32 | s30 0.0 0.0 i30 c0_v0 c1_v1 c2_v0 33 | s31 0.0 0.0 i31 c0_v1 c1_v1 c2_v0 34 | s32 0.0 0.0 i32 c0_v2 c1_v1 c2_v0 35 | s33 0.0 0.0 i33 c0_v0 c1_v2 c2_v0 36 | s34 0.0 0.0 i34 c0_v1 c1_v2 c2_v0 37 | s35 0.0 0.0 i35 c0_v2 c1_v2 c2_v0 38 | s36 0.0 0.0 i36 c0_v0 c1_v0 c2_v1 39 | s37 0.0 0.0 i37 c0_v1 c1_v0 c2_v1 40 | s38 0.0 0.0 i38 c0_v2 c1_v0 c2_v1 41 | s39 0.0 0.0 i39 c0_v0 c1_v1 c2_v1 42 | s40 0.0 0.0 i40 c0_v1 c1_v1 c2_v1 43 | s41 0.0 0.0 i41 c0_v2 c1_v1 c2_v1 44 | s42 0.0 0.0 i42 c0_v0 c1_v2 c2_v1 45 | s43 0.0 0.0 i43 c0_v1 c1_v2 c2_v1 46 | s44 0.0 0.0 i44 c0_v2 c1_v2 c2_v1 47 | s45 0.0 0.0 i45 c0_v0 c1_v0 c2_v2 48 | s46 0.0 0.0 i46 c0_v1 c1_v0 c2_v2 49 | s47 0.0 0.0 i47 c0_v2 c1_v0 c2_v2 50 | s48 0.0 0.0 i48 c0_v0 c1_v1 c2_v2 51 | s49 0.0 0.0 i49 c0_v1 c1_v1 c2_v2 52 | s50 0.0 0.0 i50 c0_v2 c1_v1 c2_v2 53 | s51 0.0 0.0 i51 c0_v0 c1_v2 c2_v2 54 | s52 0.0 0.0 i52 c0_v1 c1_v2 c2_v2 55 | s53 0.0 0.0 i53 c0_v2 c1_v2 c2_v2 56 | -------------------------------------------------------------------------------- /ABXpy/distances/metrics/utw.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def utw(x, y, div): 4 | """Uniform Time Warping 5 | 6 | Parameters 7 | ---------- 8 | x : numpy.Array 9 | Size m_x by n 10 | y : numpy.Array 11 | Size m_y by n 12 | div : function 13 | Takes two m by n arrays and returns a m by 1 array 14 | containing the distances between rows 1, rows 2, ..., 15 | rows m of the two arrays. 
Does not need to be a metric 16 | in the mathematical sense 17 | 18 | Returns 19 | ------- 20 | dis : float 21 | The UTW distance between x and y according to 22 | distance function div 23 | """ 24 | if x.shape[0] > y.shape[0]: 25 | z = x, y 26 | y, x = z 27 | n1 = x.shape[0] 28 | n2 = y.shape[0] 29 | i_half, j_half, i_whole, j_whole = distance_coordinates(n1, n2) 30 | dis = 0 31 | if i_half.size > 0: 32 | dis = dis + np.sum(div(x[i_half, :], y[j_half, :])) / 2. 33 | if i_whole.size > 0: 34 | dis = dis + np.sum(div(x[i_whole, :], y[j_whole, :])) 35 | return dis 36 | 37 | 38 | def distance_coordinates(n1, n2): 39 | assert(n1 <= n2) 40 | l1 = np.arange(n2) 41 | l2 = n1/np.float(n2)*(np.arange(n2)+0.5) 42 | l3 = np.floor(l2).astype(np.int) 43 | integers = np.where(l2 == l3)[0] 44 | non_integers = np.where(l2 != l3)[0] 45 | l3i = l3[integers] 46 | l1i = l1[integers] 47 | i_half = np.concatenate([l3i-1, l3i]) 48 | j_half = np.concatenate([l1i, l1i]) 49 | i_whole = l3[non_integers] 50 | j_whole = l1[non_integers] 51 | return i_half, j_half, i_whole, j_whole 52 | 53 | 54 | # def test(): 55 | # metric = lambda x, y: np.sum(np.abs(x-y), axis=1) 56 | # x = np.array([[0, 0], 57 | # [0, -1], 58 | # [0, 0], 59 | # [4, 0], 60 | # [0, 1], 61 | # [-4, 5], 62 | # [5, 0]]) 63 | # y = np.array([[0, 1], 64 | # [0, 1], 65 | # [0, -2], 66 | # [1, 1]]) 67 | # assert(utw(x, y, metric) == 26.5) 68 | # assert(utw(y, x, metric) == 26.5) 69 | 70 | # test() 71 | -------------------------------------------------------------------------------- /ABXpy/misc/progress_display.py: -------------------------------------------------------------------------------- 1 | """Displays the progress during the computing.""" 2 | 3 | import os 4 | import sys 5 | import collections 6 | 7 | 8 | class ProgressDisplay(object): 9 | def __init__(self): 10 | self.message = collections.OrderedDict() 11 | self.total = collections.OrderedDict() 12 | self.count = collections.OrderedDict() 13 | self.init = True 14 | 15 | # FIXME the goal of this is to determine whether using \033[A will 16 | # move the standard output n lines backwards, but I'm not sure this is 17 | # something that would work on all tty devices ... might rather be a 18 | # VT100 feature, but not sure how to detect if the stdio is a VT100 19 | # from python ... 
20 | if os.isatty(sys.stdin.fileno()): 21 | self.is_tty = True 22 | else: 23 | self.is_tty = False 24 | 25 | def add(self, name, message, total): 26 | self.message[name] = message 27 | self.total[name] = total 28 | self.count[name] = 0 29 | 30 | def update(self, name, amount): 31 | self.count[name] = self.count[name] + amount 32 | 33 | def display(self): 34 | if self.is_tty: 35 | # move back up several lines (in bash) 36 | m = "\033[<%d>A" % len(self.message) 37 | else: 38 | m = "" 39 | if self.init: 40 | m = "" 41 | self.init = False 42 | for message, total, count in zip(self.message.values(), 43 | self.total.values(), 44 | self.count.values()): 45 | m = m + "%s %d on %d\n" % (message, count, total) 46 | sys.stdout.write(m) 47 | sys.stdout.flush() 48 | 49 | 50 | # # Test 51 | # import time 52 | 53 | 54 | # def testProgressDisplay(): 55 | 56 | # d = ProgressDisplay() 57 | # d.add('m1', 'truc 1', 12) 58 | # d.add('m2', 'truc 2', 48) 59 | # d.add('m3', 'truc 3', 24) 60 | 61 | # for i in range(12): 62 | # d.update('m1', 1) 63 | # d.update('m2', 4) 64 | # d.update('m3', 2) 65 | # d.display() 66 | # time.sleep(0.1) 67 | -------------------------------------------------------------------------------- /ABXpy/__init__.py: -------------------------------------------------------------------------------- 1 | """ABX discrimination test in Python 2 | 3 | ABX discrimination is a term that is used for three stimuli presented on an 4 | ABX trial. The third is the focus. The first two stimuli (A and B) are 5 | standard, S1 and S2 in a randomly chosen order, and the subjects' task is to 6 | choose which of the two is matched by the final stimulus (X). (Glottopedia) 7 | 8 | This package contains the operations necessary to initialize, calculate and 9 | analyse the results of an ABX discrimination task. 10 | 11 | Organisation 12 | ------------ 13 | 14 | It is composed of 3 main modules and other submodules. 15 | 16 | - `task module`_ is used for creating a new task and preprocessing. 17 | 18 | - `distance package `_ is used for calculating the 19 | distances necessary for the score calculation. 20 | 21 | - `score module`_ is used for computing the score of a task. 22 | 23 | - `analyze module`_ is used for analysing the results. 24 | 25 | The features can be calculated in numpy via external tools, and made 26 | compatible with this package with the `npz2h5features 27 | `_ 28 | function. 29 | 30 | The pipeline 31 | ------------ 32 | 33 | +-------------------+----------+-----------------+ 34 | | In | Module | Out | 35 | +===================+==========+=================+ 36 | | - data.item | task | - data.abx | 37 | | - parameters | | | 38 | +-------------------+----------+-----------------+ 39 | | - data.abx | distance | - data.distance | 40 | | - data.features | | | 41 | | - distance | | | 42 | +-------------------+----------+-----------------+ 43 | | - data.abx | score | - data.score | 44 | | - data.distance | | | 45 | +-------------------+----------+-----------------+ 46 | | - data.abx | analyse | - data.csv | 47 | | - data.score | | | 48 | +-------------------+----------+-----------------+ 49 | 50 | See `Files Format `_ for a description of the files used as 51 | input and output. 
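A minimal end-to-end run, condensed from ``examples/complete_run.py`` (the
file names are simply the ones used in that example), looks like this::

    import ABXpy.task
    import ABXpy.distances.distances as distances
    import ABXpy.distances.metrics.cosine as cosine
    import ABXpy.distances.metrics.dtw as dtw
    import ABXpy.score as score
    import ABXpy.analyze as analyze

    def dtw_cosine(x, y, normalized):
        return dtw.dtw(x, y, cosine.cosine_distance, normalized)

    # task: generate the ABX triplets from the item file
    task = ABXpy.task.Task('data.item', 'c0', across='c1', by='c2')
    task.generate_triplets('data.abx')

    # distance: compute a DTW/cosine distance for every pair in the task
    distances.compute_distances(
        'data.features', 'features', 'data.abx', 'data.distance',
        dtw_cosine, normalized=True, n_cpu=1)

    # score the task, then collapse the results into a csv file
    score.score('data.abx', 'data.distance', 'data.score')
    analyze.analyze('data.abx', 'data.score', 'data.csv')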
52 | 53 | """ 54 | 55 | version = "0.4.3" 56 | -------------------------------------------------------------------------------- /.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | build-linux: 2 | stage: build 3 | script: 4 | - module load anaconda/3 5 | - conda activate abx-ci 6 | - conda env update -n abx-ci -f environment.yml 7 | - python setup.py install 8 | 9 | test-linux: 10 | stage: test 11 | script: 12 | - module load anaconda/3 13 | - conda activate abx-ci 14 | - python setup.py test 15 | 16 | # build-osx: 17 | # tags: 18 | # - macos 19 | # stage: build 20 | # script: 21 | # - conda activate abx 22 | # - conda env update -n abx -f environment.yml 23 | # - python setup.py install 24 | 25 | # test-osx: 26 | # tags: 27 | # - macos 28 | # stage: test 29 | # script: 30 | # - conda activate abx 31 | # - python setup.py test 32 | 33 | # abx package is available on oberon with "conda activate abx" 34 | deploy-oberon: 35 | stage: deploy 36 | only: 37 | refs: 38 | - master 39 | script: 40 | - module load anaconda/3 41 | - cd /shared/apps/ABXpy 42 | - git pull origin master 43 | - conda env update -n abx -f environment.yml 44 | - conda activate abx 45 | - python setup.py install 46 | - python setup.py test 47 | 48 | # documentation is available on https://docs.cognitive-ml.fr/ABXpy 49 | deploy-doc: 50 | stage: deploy 51 | only: 52 | refs: 53 | - master 54 | script: 55 | - module load anaconda/3 56 | - module load texlive/2018 57 | - conda activate abx-ci 58 | - sphinx-build doc build/doc 59 | - scp -r build/doc/* cognitive-ml.fr:/var/www/docs.cognitive-ml.fr/ABXpy 60 | 61 | # abx package available on conda with "conda install -c coml 62 | # abx". Build and upload the package only on new git tags or 63 | # manual triggers. 64 | deploy-conda-linux: 65 | stage: deploy 66 | only: 67 | - tags 68 | - triggers 69 | script: 70 | - module load anaconda/3 71 | - conda activate abx-ci 72 | - cd .conda 73 | - conda build -c coml -c conda-forge --user coml --token $CONDA_TOKEN --skip-existing . 74 | - conda build purge 75 | 76 | deploy-conda-osx: 77 | tags: 78 | - macos 79 | stage: deploy 80 | only: 81 | - tags 82 | - triggers 83 | script: 84 | - conda activate abx 85 | - cd .conda 86 | - conda build -c coml -c conda-forge --user coml --token $CONDA_TOKEN --skip-existing . 87 | - conda build purge 88 | -------------------------------------------------------------------------------- /ABXpy/verify.py: -------------------------------------------------------------------------------- 1 | """This script is used to verify the consistency of your input files. 2 | 3 | Usage 4 | ----- 5 | 6 | From the command line: 7 | 8 | .. code-block:: bash 9 | 10 | python verify.py my_data.item my_features.h5f 11 | 12 | In python: 13 | 14 | .. 
code-block:: python 15 | 16 | import ABXpy.verify 17 | # create a new task and compute the statistics 18 | ABXpy.verify.check('my_data.item', 'my_data.h5f') 19 | """ 20 | 21 | import argparse 22 | import h5py 23 | 24 | 25 | def check(item_file, features_file, verbose=0): 26 | """check the consistency between the item file and the features file 27 | 28 | Parameters: 29 | item_file: str 30 | the item file defining the database 31 | features_file : str 32 | the features file to be tested 33 | """ 34 | if verbose: 35 | print("Opening item file") 36 | with open(item_file) as f: 37 | cols = str.split(f.readline()) 38 | assert len(cols) >= 4, 'the syntax of the item file is incorrect' 39 | assert cols[0] == '#file', 'The first column must be named #file' 40 | assert cols[1] == 'onset', 'The second column must be named onset' 41 | assert cols[2] == 'offset', 'The third column must be named offset' 42 | assert cols[3][0] == '#', 'The fourth column must start with #' 43 | if verbose: 44 | print("Opening features file") 45 | h5f = h5py.File(features_file) 46 | files = h5f['features']['files'][:] 47 | for line in f: 48 | source = str.split(line, ' ')[0] 49 | assert source in files, ("The file {} cannot " 50 | "be found in the feature file" 51 | .format(source)) 52 | 53 | 54 | def parse_args(): 55 | parser = argparse.ArgumentParser( 56 | prog='collapse_results.py', 57 | formatter_class=argparse.RawDescriptionHelpFormatter, 58 | description='Collapse results of ABX on by conditions.', 59 | epilog="""Example usage: 60 | 61 | $ ./verify.py my_data.item my_features.h5f 62 | 63 | verify the consistency between the item file and the features file""") 64 | parser.add_argument('item', metavar='ITEM_FILE', 65 | help='database description file in .item format') 66 | parser.add_argument('features', metavar='FEATURES_FILE', 67 | help='features file in h5features format') 68 | return vars(parser.parse_args()) 69 | 70 | 71 | if __name__ == '__main__': 72 | args = parse_args() 73 | check(args['item'], args['features']) 74 | -------------------------------------------------------------------------------- /ABXpy/misc/any2h5features.py: -------------------------------------------------------------------------------- 1 | """This script contains functions to convert numpy savez file into the 2 | h5features format. 3 | 4 | The npz files must respect the following conventions: 5 | They must contains 2 arrays: 6 | 7 | - a 1D-array named 'times' 8 | - a 2D-array named 'features', the 'feature' dimension along the columns and\ 9 | the 'time' dimension along the lines 10 | 11 | """ 12 | 13 | import h5features 14 | import os 15 | import numpy as np 16 | 17 | 18 | def any_to_h5features(path, files, h5_filename, h5_groupname, 19 | batch_size=500, load=np.load): 20 | """Append a list of npz files to a h5features file. 21 | 22 | Files must have a relative name to a directory precised by the 'path' 23 | argument. 24 | 25 | Parameters 26 | ---------- 27 | path : str 28 | Path of the directory where the numpy files are stored. 29 | files : list of filename 30 | List of file to convert and append. 31 | h5_filename : filename 32 | The output h5features file. 33 | h5_groupname : str 34 | Name of the h5 group where to store the numpy files (use '/features/') 35 | for h5features files) 36 | batch_size : int 37 | Size of the writing buffer (in number of npz files). By default 500. 
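load : callable, optional
    Function used to read one input file: it takes a file path and returns
    a dictionary with a 'time' entry and a 'features' entry (by default
    ``numpy.load``, suitable for npz files).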
38 | 39 | """ 40 | features = [] 41 | times = [] 42 | internal_files = [] 43 | i = 0 44 | for f in files: 45 | if i == batch_size: 46 | h5features.write(h5_filename, h5_groupname, internal_files, times, 47 | features) 48 | features = [] 49 | times = [] 50 | internal_files = [] 51 | i = 0 52 | i = i+1 53 | data = load(os.path.join(path, f)) 54 | features.append(data['features']) 55 | times.append(data['time']) 56 | internal_files.append(os.path.splitext(f)[0]) 57 | if features: 58 | h5features.write(h5_filename, h5_groupname, internal_files, times, 59 | features) 60 | 61 | 62 | def convert(folder, h5_filename='./features.features', load=np.load): 63 | """Append a folder of numpy ndarray files in npz format into a h5features 64 | file. 65 | 66 | Parameters 67 | ---------- 68 | folder : dirname 69 | The folder containing the npz files to convert. 70 | h5_filename : filename 71 | The output h5features file. 72 | load : callable 73 | Python function that take a filepath as input and return a 74 | dictionary {'time': times, 'features': features}, (times and 75 | features being array-like containing respectively the centered 76 | times of the frames and the features) 77 | 78 | """ 79 | files = os.listdir(folder) 80 | any_to_h5features(folder, files, h5_filename, '/features/', load=load) 81 | 82 | 83 | if __name__ == '__main__': 84 | import argparse 85 | 86 | parser = argparse.ArgumentParser() 87 | parser.add_argument('folder', help='folder containing the files ' 88 | 'to be converted') 89 | parser.add_argument('h5_filename', 90 | help='desired path for the h5features file') 91 | args = parser.parse_args() 92 | convert(args.folder, args.h5_filename) 93 | -------------------------------------------------------------------------------- /ABXpy/distances/metrics/dtw/dtw.pyx: -------------------------------------------------------------------------------- 1 | """Created on Tue Jan 21 14:15:44 2014 2 | 3 | @author: Thomas Schatz adapted from Gabriel Synaeve's code 4 | 5 | The "feature" dimension is along the columns and the "time" dimension 6 | along the lines of arrays x and y. 7 | 8 | The function do not verify its arguments, common problems are: 9 | shape of one array is n instead of (n,1) 10 | an array is not of the correct type DTYPE_t 11 | the feature dimension of the two array do not match 12 | the feature and time dimension are exchanged 13 | the dist_array is not of the correct size or type 14 | 15 | """ 16 | 17 | import numpy as np 18 | cimport numpy as np 19 | cimport cython 20 | from cpython cimport bool 21 | ctypedef np.float64_t CTYPE_t # cost type 22 | ctypedef np.intp_t IND_t # array index type 23 | CTYPE = np.float64 # cost type 24 | 25 | 26 | def dtw(x, y, metric, normalized): 27 | if x.shape[0] == 0 or y.shape[0] == 0: 28 | raise ValueError('Cannot compute distance between empty representations') 29 | else: 30 | return _dtw(x.shape[0], y.shape[0], metric(x,y), normalized) 31 | 32 | 33 | # There was a bug at initialization in both Dan Ellis DTW and Gabriel's code: 34 | # Dan Ellis: do not take into account distance between the first frame of x and the first frame of y 35 | # Gabriel: init cost[0,:] and cost[:,0] by dist_array[0,:], resp. dist_array[:,0] instead of their cumsum 36 | #FIXME retest negligeability of min ? 
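# Usage sketch for the dtw() wrapper above, kept as a comment (it mirrors the
# call made in ABXpy/distance.py; the array shapes are illustrative only):
#
#   import numpy as np
#   from ABXpy.distances.metrics import cosine
#   x = np.random.randn(20, 13)  # 20 frames of 13-dimensional features
#   y = np.random.randn(25, 13)
#   d = dtw(x, y, cosine.cosine_distance, normalized=True)
#
# The metric argument must return the (x.shape[0], y.shape[0]) array of
# frame-to-frame distances; it is evaluated once and then accumulated by the
# dynamic programming loop below.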
37 | cpdef _dtw(IND_t N, IND_t M, CTYPE_t[:,:] dist_array, bool normalized): 38 | cdef IND_t i, j 39 | cdef CTYPE_t[:,:] cost = np.empty((N, M), dtype=CTYPE) 40 | cdef CTYPE_t final_cost, c_diag, c_left, c_up 41 | # initialization 42 | cost[0,0] = dist_array[0,0] 43 | for i in range(1,N): 44 | cost[i,0] = dist_array[i,0] + cost[i-1,0] 45 | for j in range(1,M): 46 | cost[0,j] = dist_array[0,j] + cost[0,j-1] 47 | # the dynamic programming loop 48 | for i in range(1,N): 49 | for j in range(1,M): 50 | cost[i,j] = dist_array[i,j] + min(cost[i-1,j], cost[i-1,j-1], cost[i,j-1]) 51 | 52 | final_cost = cost[N-1, M-1] 53 | if normalized: 54 | path_len = 1 55 | i = N-1 56 | j = M-1 57 | while i > 0 and j > 0: 58 | c_up = cost[i-1, j] 59 | c_left = cost[i, j-1] 60 | c_diag = cost[i-1, j-1] 61 | if c_diag <= c_left and c_diag <= c_up: 62 | i -= 1 63 | j -= 1 64 | elif c_left <= c_up: 65 | j -= 1 66 | else: 67 | i -= 1 68 | path_len += 1 69 | if i == 0: 70 | path_len += j 71 | if j == 0: 72 | path_len += i 73 | final_cost /= path_len 74 | return final_cost 75 | 76 | 77 | # import numpy as np 78 | # cimport numpy as np 79 | # cimport cython 80 | # ctypedef np.float64_t DTYPE_t # feature type (could be int) 81 | # ctypedef np.float64_t CTYPE_t # cost type 82 | # ctypedef np.intp_t IND_t # array index type 83 | # CTYPE = np.float64 # cost type 84 | 85 | # cpdef DTW(DTYPE_t[:,:] x, DTYPE_t[:,:] y, CTYPE_t[:,:] dist_array): 86 | # cdef IND_t N = x.shape[0] 87 | # cdef IND_t M = y.shape[0] 88 | # cdef IND_t K = x.shape[1] 89 | # cdef IND_t i, j 90 | # cdef CTYPE_t[:,:] cost = np.empty((N, M), dtype=CTYPE) 91 | # # initialization 92 | # cost[:,0] = dist_array[:,0] 93 | # cost[0,:] = dist_array[0,:] 94 | # # the dynamic programming loop 95 | # for i in range(1, N): 96 | # for j in range(1, M): 97 | # cost[i,j] = dist_array[i,j] + min(cost[i-1,j], cost[i-1,j-1], cost[i,j-1]) 98 | # return cost[N-1,M-1] 99 | -------------------------------------------------------------------------------- /ABXpy/misc/items.py: -------------------------------------------------------------------------------- 1 | """Generate test files for the test functions""" 2 | 3 | import filecmp 4 | import h5features 5 | import numpy as np 6 | import os 7 | import pandas 8 | import subprocess 9 | 10 | 11 | def generate_testitems(base, n, repeats=0, name='data.item'): 12 | """Minimal item file generator for task.py""" 13 | dim_i, dim_j = base ** n * (repeats + 1) + 1, n + 4 14 | 15 | def fun(i, j): 16 | if i == 0: 17 | if j < 4: 18 | return ['#file', 'onset', 'offset', '#src'][j] 19 | else: 20 | return 'c%s' % (j - 4) 21 | elif j < 4: 22 | return 'snfi'[j] + str(i - 1) 23 | else: 24 | i -= 1 25 | j -= 4 26 | return i // (base ** j) % base 27 | 28 | res = [[fun(i, j) for j in range(dim_j)] 29 | for i in range(dim_i)] 30 | np.savetxt(name, res, delimiter=' ', fmt='%s') 31 | 32 | 33 | def generate_named_testitems(base, n, repeats=0, name='data.item'): 34 | """Extended item file generator 35 | """ 36 | 37 | dim_i, dim_j = base ** n * (repeats + 1) + 1, n + 4 38 | 39 | def fun(i, j): 40 | if i == 0: 41 | if j < 4: 42 | return ['#file', 'onset', 'offset', '#item'][j] 43 | else: 44 | return 'c%s' % (j - 4) 45 | elif j == 0: 46 | return 's%s' % (i - 1) 47 | elif j in [1, 2]: 48 | return 0 49 | elif j == 3: 50 | return 'i%s' % (i-1) 51 | else: 52 | i -= 1 53 | j -= 4 54 | return 'c%s_v%s' % (j, i // (base ** j) % base) 55 | 56 | res = [[fun(i, j) for j in range(dim_j)] 57 | for i in range(dim_i)] 58 | np.savetxt(name, res, delimiter=' ', fmt='%s') 59 | 60 
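# A usage sketch, kept as a comment (it mirrors the call made in
# test/test_analyse.py): generate a small item file with 3 label columns
# ('c0', 'c1', 'c2') taking 3 values each, plus matching random features,
# using the generate_db_and_feat function defined below:
#
#   generate_db_and_feat(3, 3, 1, 'data.item', 2, 3, 'data.features')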
| 61 | def generate_features(n_files, n_feat=2, max_frames=3, name='data.features'): 62 | """Random feature file generator 63 | """ 64 | if os.path.exists(name): 65 | os.remove(name) 66 | features = [] 67 | times = [] 68 | files = [] 69 | for i in range(n_files): 70 | n_frames = np.random.randint(max_frames) + 1 71 | features.append(np.random.randn(n_frames, n_feat)) 72 | times.append(np.linspace(0, 1, n_frames)) 73 | files.append('s%d' % i) 74 | h5features.write(name, 'features', files, times, features) 75 | 76 | 77 | def generate_db_and_feat(base, n, repeats=0, name_db='data.item', n_feat=2, 78 | max_frames=3, name_feat='data.features'): 79 | """Item and feature files generator 80 | """ 81 | generate_named_testitems(base, n, repeats, name_db) 82 | print(name_db) 83 | n_files = (base ** n) * (repeats + 1) 84 | generate_features(n_files, n_feat, max_frames, name_feat) 85 | 86 | 87 | def cmp(f1, f2): 88 | return filecmp.cmp(f1, f2, shallow=True) 89 | 90 | 91 | def h5cmp(f1, f2): 92 | try: 93 | out = subprocess.check_output(['h5diff', f1, f2]) 94 | except subprocess.CalledProcessError: 95 | return False 96 | if out: 97 | return False 98 | else: 99 | return True 100 | 101 | 102 | def csv_cmp(f1, f2, sep='\t'): 103 | """Returns True if the 2 CSV files are equal 104 | 105 | The comparison is not sensitive to the column order. 106 | 107 | """ 108 | csv1 = pandas.read_csv(f1, sep=sep) 109 | csv1 = csv1.reindex(sorted(csv1.columns), axis=1) 110 | 111 | csv2 = pandas.read_csv(f2, sep=sep) 112 | csv2 = csv2.reindex(sorted(csv2.columns), axis=1) 113 | 114 | return sorted(csv1.to_csv().split('\n')) == sorted(csv2.to_csv().split('\n')) 115 | -------------------------------------------------------------------------------- /test/test_sampling.py: -------------------------------------------------------------------------------- 1 | """This test script contains the tests for the sampling module and its use 2 | with task.py""" 3 | 4 | import random 5 | import warnings 6 | 7 | from scipy.stats import chisquare as chisquare 8 | import numpy as np 9 | 10 | import ABXpy.sampling as sampling 11 | 12 | 13 | # FIXME problems when K > N/2 (not important ?) 14 | def chi2test(frequencies, significance): 15 | # dof = len(frequencies) - 1 16 | dof = 0 17 | (_, p) = chisquare(frequencies, ddof=dof) 18 | return p > significance 19 | 20 | 21 | def _test_no_replace(N, K): 22 | """Test the correct functionality of the sampler, and particularly the no 23 | replacement property in simple sample""" 24 | sampler = sampling.sampler.IncrementalSampler(N, K) 25 | indices = sampler.sample(N) 26 | indices_test = np.array(list(set(indices))) 27 | assert len(indices_test) == len(indices) 28 | assert len(indices) == K 29 | 30 | 31 | def _test_completion(N, K, n): 32 | """Test the exact completion of the sample (N need not be a multiple of n) 33 | """ 34 | sampler = sampling.sampler.IncrementalSampler(N, K) 35 | count = 0 36 | for j in range(N // n): 37 | indices = sampler.sample(n) 38 | count += len(indices) 39 | indices = sampler.sample(N % n) 40 | count += len(indices) 41 | assert count == K 42 | 43 | 44 | # this function is really not optimised 45 | def _test_uniformity(N, K, n, nbins=10, significance=0.001): 46 | """Test the uniformity of a sample 47 | 48 | .. note:: This test is not exact and may return false even if the function 49 | is correct. Use the significance wisely.
50 | 51 | Parameters: 52 | ----------- 53 | nbins : int 54 | the number of bins for the Chi2 test 55 | """ 56 | sampler = sampling.sampler.IncrementalSampler(N, K) 57 | distr = [] 58 | bins = np.zeros(nbins, np.int64) 59 | for j in range(N / n): 60 | indices = sampler.sample(n) + n * j 61 | distr.extend(indices.tolist()) 62 | for i in distr: 63 | bins[i * nbins / N] += 1 64 | assert chi2test(bins, significance) 65 | 66 | 67 | def test_simple_completion(): 68 | for i in range(1000): 69 | N = random.randint(1000, 10000) 70 | with warnings.catch_warnings(): 71 | warnings.simplefilter("ignore") 72 | _test_completion(N, K=random.randrange(100, N // 2), 73 | n=random.randrange(50, N)) 74 | 75 | 76 | def test_simple_no_replace(): 77 | for i in range(100): 78 | N = random.randint(1000, 10000) 79 | _test_no_replace(N, random.randint(100, N // 2)) 80 | 81 | 82 | def test_hard_completion(): 83 | for i in range(3): 84 | N = random.randint(10 ** 6, 10 ** 7) 85 | _test_completion(N, K=random.randrange(10 ** 5, N // 2), 86 | n=random.randrange(10 ** 5, N)) 87 | 88 | 89 | def test_hard_no_replace(): 90 | for i in range(3): 91 | N = random.randint(10 ** 6, 10 ** 7) 92 | _test_no_replace(N, K=random.randrange(10 ** 5, N // 2)) 93 | 94 | 95 | def test_simple_uniformity(): 96 | for i in range(100): 97 | N = random.randint(1000, 10000) 98 | _test_completion(N, K=random.randrange(100, N // 2), 99 | n=random.randrange(50, N)) 100 | 101 | 102 | # import matplotlib.pyplot as plt 103 | # 104 | # 105 | # def plot_uniformity(nb_resamples, N, K): 106 | # indices = [] 107 | # for i in range(nb_resamples): 108 | # if i % 1000 == 0: 109 | # print('%d resamples left to do' % (nb_resamples - i)) 110 | # sampler = sampling.sampler.IncrementalSampler(N, K) 111 | # current_N = 0 112 | # while current_N < N: 113 | # n = min(random.randrange(N / 10), N - current_N) 114 | # indices = indices + list(sampler.sample(n) + current_N) 115 | # current_N = current_N + n 116 | # plt.hist(indices, bins=100) 117 | 118 | 119 | # test_simple_uniformity() 120 | # test_hard_completion() 121 | # test_hard_no_replace() 122 | # test_simple_no_replace() 123 | # test_simple_completion() 124 | # plot_uniformity(10**4, 10**3, 10) 125 | -------------------------------------------------------------------------------- /ABXpy/database/database.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas 3 | import numpy 4 | import ABXpy.misc.tinytree as tinytree 5 | 6 | # FIXME use just one isolated | as a separator instead of two # 7 | 8 | 9 | # custom read_table that ignore empty entries at the end of a file (they 10 | # can result from trailing white spaces at the end for example) 11 | def read_table(filename): 12 | db = pandas.read_csv(filename, sep='[ \t]+', engine='python') 13 | # removes row with all null values (None or NaN...) 
14 | db = db.dropna(how='all') 15 | return db 16 | 17 | 18 | # function that loads a database 19 | def load(filename, features_info=False): 20 | # reading the main database using pandas (it is now a DataFrame) 21 | ext = '.item' 22 | if not(filename[len(filename) - len(ext):] == ext): 23 | filename = filename + ext 24 | db = read_table(filename) 25 | 26 | # finding '#' (to separate location info from attribute info) and fixing 27 | # names of columns 28 | columns = db.columns.tolist() 29 | 30 | # check the 3 first columns are '#file', 'onset', 'offset' 31 | assert ' '.join(columns[:3]) == '#file onset offset', ( 32 | 'The first 3 columns of the item file must be "#file onset offset". ' 33 | 'They are "{}"'.format(' '.join(columns[:3]))) 34 | 35 | l = [] 36 | for i, c in enumerate(columns): 37 | if c[0] == "#": 38 | l.append(i) 39 | columns[i] = c[1:] 40 | db.columns = pandas.Index(columns) 41 | assert len(l) > 0 and l[0] == 0, ( 42 | 'The first column name in the database main file must be ' 43 | 'prefixed with # (sharp)') 44 | assert len(l) == 2, ( 45 | 'Exactly two column names in the database main file must be' 46 | ' prefixed with a # (sharp)') 47 | feat_db = db[db.columns[:l[1]]] 48 | db = db[db.columns[l[1]:]] 49 | # verbose print(" Read input File '"+filename+"'. Defined conditions: 50 | # "+str(newcolumns[attrI:len(columns)])) 51 | 52 | # opening up existing auxiliary files, and merging with the main database 53 | # and creating a forest describing the hierarchy at the same time (useful 54 | # for optimizing regressor generation and filtering) 55 | 56 | (basename, _) = os.path.splitext(filename) 57 | db, db_hierarchy = load_aux_dbs(basename, db, db.columns, filename) 58 | 59 | # dealing with missing items: for now rows with missing items are dropped 60 | nanrows = numpy.any(pandas.isnull(db), 1) 61 | if any(nanrows): 62 | dropped = db[nanrows] 63 | dropped.to_csv(basename + '-removed' + ext) 64 | db = db[~ nanrows] 65 | feat_db = feat_db[~ nanrows] 66 | # not so verbose print('** Warning ** ' + len(nanrows) + ' items were 67 | # removed because of missing information. The removed items are listed in' 68 | # + basename + '-removed.item' 69 | if features_info: 70 | return db, db_hierarchy, feat_db 71 | else: 72 | return db, db_hierarchy 73 | 74 | 75 | # recursive auxiliary function for loading the auxiliary databases 76 | def load_aux_dbs(basename, db, cols, mainfile): 77 | forest = [tinytree.Tree() for col in cols] 78 | for i, col in enumerate(cols): 79 | forest[i].name = col 80 | try: 81 | auxfile = basename + '.' + col 82 | if not(auxfile == mainfile): 83 | auxdb = read_table(auxfile) 84 | assert col == auxdb.columns[0], ( 85 | 'First column name in file %s' 86 | ' is %s. It should be %s instead.' % ( 87 | auxfile, auxdb.columns[0], col)) 88 | # call load_aux_dbs recursively on the child columns 89 | auxdb, auxforest = load_aux_dbs( 90 | basename, auxdb, auxdb.columns[1:], mainfile) 91 | # add to forest 92 | forest[i].addChildrenFromList(auxforest) 93 | # merging the databases 94 | db = pandas.merge(db, auxdb, on=col, how='left') 95 | # verbose print(" Read auxiliary File '"+auxfile+"'.
Defined 96 | # conditions: "+str(newcol[1:len(newcol)])+" on key '"+newcol[0]+"'") 97 | except IOError: 98 | pass 99 | return db, forest 100 | -------------------------------------------------------------------------------- /ABXpy/distances/metrics/kullback_leibler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def kl_ptwise(x, y): 5 | return np.sum(x * (np.log(x) - np.log(y))) 6 | 7 | 8 | def js_ptwise(x, y): 9 | m = (x + y) / 2 10 | return 0.5 * kl_ptwise(x, m) + 0.5 * kl_ptwise(y, m) 11 | 12 | 13 | def __kl_divergence(x, y): 14 | """ just the KL-div """ 15 | pq = np.dot(x, np.log(y.transpose())) 16 | pp = np.tile( 17 | np.sum(x * np.log(x), axis=1).reshape(x.shape[0], 1), (1, y.shape[0])) 18 | return pp - pq 19 | 20 | 21 | def kl_divergence(x, y, thresholded=True, symmetrized=True, normalize=True): 22 | """Kullback-Leibler divergence 23 | 24 | x and y should be 2D numpy arrays with "times" on the lines and 25 | "features" on the columns 26 | 27 | - thresholded=True => means we add an epsilon to all the dimensions/values 28 | AND renormalize inputs. 29 | 30 | - symmetrized=True => uses the symmetrized KL (0.5 x->y + 0.5 y->x). 31 | 32 | - normalize=True => normalize the inputs so that lines sum to one. 33 | 34 | """ 35 | assert (x.dtype == np.float64 and y.dtype == np.float64) or ( 36 | x.dtype == np.float32 and y.dtype == np.float32) 37 | # assert (np.all(x.sum(1) != 0.) and np.all(y.sum(1) != 0.)) 38 | if thresholded: 39 | normalize = True 40 | if normalize: 41 | x /= x.sum(1).reshape(x.shape[0], 1) 42 | y /= y.sum(1).reshape(y.shape[0], 1) 43 | if thresholded: 44 | eps = np.finfo(x.dtype).eps 45 | x = x + eps 46 | y = y + eps 47 | x /= x.sum(1).reshape(x.shape[0], 1) 48 | y /= y.sum(1).reshape(y.shape[0], 1) 49 | res = __kl_divergence(x, y) 50 | 51 | if symmetrized: 52 | res = 0.5 * res + 0.5 * __kl_divergence(y, x).transpose() 53 | 54 | return np.float64(res).reshape(res.shape) 55 | 56 | 57 | def js_divergence(x, y, normalize=True): 58 | """Jensen-Shannon divergence 59 | 60 | x and y should be 2D numpy arrays with "times" on the lines and 61 | "features" on the columns - normalize=True => normalize the inputs 62 | so that lines sum to one. 63 | 64 | """ 65 | assert (x.dtype == np.float64 and y.dtype == np.float64) or ( 66 | x.dtype == np.float32 and y.dtype == np.float32) 67 | assert (np.all(x.sum(1) != 0.) and np.all(y.sum(1) != 0.)) 68 | 69 | if normalize: 70 | x /= x.sum(1).reshape(x.shape[0], 1) 71 | y /= y.sum(1).reshape(y.shape[0], 1) 72 | xx = np.tile(x, (y.shape[0], 1, 1)).transpose((1, 0, 2)) 73 | yy = np.tile(y, (x.shape[0], 1, 1)) 74 | m = (xx + yy) / 2 75 | x_m = np.sum(xx * np.log(m), axis=2) 76 | y_m = np.sum(yy * np.log(m), axis=2) 77 | x_x = np.tile( 78 | np.sum(x * np.log(x), axis=1).reshape(x.shape[0], 1), (1, y.shape[0])) 79 | y_y = np.tile(np.sum( 80 | y * np.log(y), axis=1).reshape(y.shape[0], 1), 81 | (1, x.shape[0])).transpose() 82 | res = 0.5 * (x_x - x_m) + 0.5 * (y_y - y_m) 83 | return np.float64(res) 84 | # division by zero 85 | 86 | 87 | def sqrt_js_divergence(x, y): 88 | return np.sqrt(js_divergence(x, y)) 89 | 90 | 91 | def hellinger_distance(x, y): 92 | """Hellinger distance 93 | 94 | x and y should be 2D numpy arrays with "times" on the lines and 95 | "features" on the columns - normalize=True => normalize the inputs 96 | so that lines sum to one. 
97 | 98 | """ 99 | assert (x.dtype == np.float64 and y.dtype == np.float64) or ( 100 | x.dtype == np.float32 and y.dtype == np.float32) 101 | assert (np.all(x.sum(1) != 0.) and np.all(y.sum(1) != 0.)) 102 | x /= x.sum(1).reshape(x.shape[0], 1) 103 | y /= y.sum(1).reshape(y.shape[0], 1) 104 | x = np.sqrt(x) 105 | y = np.sqrt(y) 106 | # x (120, 40), y (100, 40), H(x,y) (120, 100) 107 | xx = np.tile(x, (y.shape[0], 1, 1)).transpose((1, 0, 2)) 108 | yy = np.tile(y, (x.shape[0], 1, 1)) 109 | xx_yy = xx - yy 110 | res = np.sqrt(np.sum(xx_yy ** 2, axis=-1)) 111 | return np.float64((1. / np.sqrt(2)) * res) 112 | 113 | 114 | def is_distance(x, y): 115 | """Itakura-Saito distance 116 | 117 | x and y should be 2D numpy arrays with "times" on the lines and 118 | "features" on the columns 119 | 120 | """ 121 | assert (x.dtype == np.float64 and y.dtype == np.float64) or ( 122 | x.dtype == np.float32 and y.dtype == np.float32) 123 | 124 | # TODO 125 | raise NotImplementedError 126 | -------------------------------------------------------------------------------- /test/frozen_files/data.csv: -------------------------------------------------------------------------------- 1 | c1_1 c0_1 c0_2 c1_2 by score n 2 | c1_v0 c0_v0 c0_v1 c1_v1 c2_v0 0.375 8 3 | c1_v0 c0_v0 c0_v1 c1_v2 c2_v0 0.0 8 4 | c1_v0 c0_v0 c0_v2 c1_v1 c2_v0 0.5 8 5 | c1_v0 c0_v0 c0_v2 c1_v2 c2_v0 0.25 8 6 | c1_v0 c0_v1 c0_v0 c1_v1 c2_v0 1.0 8 7 | c1_v0 c0_v1 c0_v0 c1_v2 c2_v0 0.25 8 8 | c1_v0 c0_v1 c0_v2 c1_v1 c2_v0 0.5 8 9 | c1_v0 c0_v1 c0_v2 c1_v2 c2_v0 0.375 8 10 | c1_v0 c0_v2 c0_v0 c1_v1 c2_v0 0.625 8 11 | c1_v0 c0_v2 c0_v0 c1_v2 c2_v0 0.875 8 12 | c1_v0 c0_v2 c0_v1 c1_v1 c2_v0 0.5 8 13 | c1_v0 c0_v2 c0_v1 c1_v2 c2_v0 0.5 8 14 | c1_v1 c0_v0 c0_v1 c1_v0 c2_v0 1.0 8 15 | c1_v1 c0_v0 c0_v1 c1_v2 c2_v0 0.25 8 16 | c1_v1 c0_v0 c0_v2 c1_v0 c2_v0 0.625 8 17 | c1_v1 c0_v0 c0_v2 c1_v2 c2_v0 0.625 8 18 | c1_v1 c0_v1 c0_v0 c1_v0 c2_v0 0.5 8 19 | c1_v1 c0_v1 c0_v0 c1_v2 c2_v0 0.5 8 20 | c1_v1 c0_v1 c0_v2 c1_v0 c2_v0 0.5 8 21 | c1_v1 c0_v1 c0_v2 c1_v2 c2_v0 0.25 8 22 | c1_v1 c0_v2 c0_v0 c1_v0 c2_v0 0.5 8 23 | c1_v1 c0_v2 c0_v0 c1_v2 c2_v0 0.5 8 24 | c1_v1 c0_v2 c0_v1 c1_v0 c2_v0 0.25 8 25 | c1_v1 c0_v2 c0_v1 c1_v2 c2_v0 0.0 8 26 | c1_v2 c0_v0 c0_v1 c1_v0 c2_v0 0.0 8 27 | c1_v2 c0_v0 c0_v1 c1_v1 c2_v0 1.0 8 28 | c1_v2 c0_v0 c0_v2 c1_v0 c2_v0 0.625 8 29 | c1_v2 c0_v0 c0_v2 c1_v1 c2_v0 0.75 8 30 | c1_v2 c0_v1 c0_v0 c1_v0 c2_v0 0.25 8 31 | c1_v2 c0_v1 c0_v0 c1_v1 c2_v0 0.0 8 32 | c1_v2 c0_v1 c0_v2 c1_v0 c2_v0 0.125 8 33 | c1_v2 c0_v1 c0_v2 c1_v1 c2_v0 0.0 8 34 | c1_v2 c0_v2 c0_v0 c1_v0 c2_v0 0.5 8 35 | c1_v2 c0_v2 c0_v0 c1_v1 c2_v0 0.5 8 36 | c1_v2 c0_v2 c0_v1 c1_v0 c2_v0 0.625 8 37 | c1_v2 c0_v2 c0_v1 c1_v1 c2_v0 0.625 8 38 | c1_v0 c0_v0 c0_v1 c1_v1 c2_v1 0.5 8 39 | c1_v0 c0_v0 c0_v1 c1_v2 c2_v1 0.5 8 40 | c1_v0 c0_v0 c0_v2 c1_v1 c2_v1 0.75 8 41 | c1_v0 c0_v0 c0_v2 c1_v2 c2_v1 0.5 8 42 | c1_v0 c0_v1 c0_v0 c1_v1 c2_v1 0.0 8 43 | c1_v0 c0_v1 c0_v0 c1_v2 c2_v1 0.75 8 44 | c1_v0 c0_v1 c0_v2 c1_v1 c2_v1 0.25 8 45 | c1_v0 c0_v1 c0_v2 c1_v2 c2_v1 0.5 8 46 | c1_v0 c0_v2 c0_v0 c1_v1 c2_v1 0.5 8 47 | c1_v0 c0_v2 c0_v0 c1_v2 c2_v1 1.0 8 48 | c1_v0 c0_v2 c0_v1 c1_v1 c2_v1 0.5 8 49 | c1_v0 c0_v2 c0_v1 c1_v2 c2_v1 0.375 8 50 | c1_v1 c0_v0 c0_v1 c1_v0 c2_v1 0.375 8 51 | c1_v1 c0_v0 c0_v1 c1_v2 c2_v1 0.5 8 52 | c1_v1 c0_v0 c0_v2 c1_v0 c2_v1 0.5 8 53 | c1_v1 c0_v0 c0_v2 c1_v2 c2_v1 0.5 8 54 | c1_v1 c0_v1 c0_v0 c1_v0 c2_v1 0.25 8 55 | c1_v1 c0_v1 c0_v0 c1_v2 c2_v1 0.0 8 56 | c1_v1 c0_v1 c0_v2 c1_v0 c2_v1 0.25 8 57 | c1_v1 c0_v1 c0_v2 c1_v2 c2_v1 0.25 8 58 | c1_v1 c0_v2 
c0_v0 c1_v0 c2_v1 0.625 8 59 | c1_v1 c0_v2 c0_v0 c1_v2 c2_v1 0.625 8 60 | c1_v1 c0_v2 c0_v1 c1_v0 c2_v1 0.375 8 61 | c1_v1 c0_v2 c0_v1 c1_v2 c2_v1 0.375 8 62 | c1_v2 c0_v0 c0_v1 c1_v0 c2_v1 0.75 8 63 | c1_v2 c0_v0 c0_v1 c1_v1 c2_v1 0.375 8 64 | c1_v2 c0_v0 c0_v2 c1_v0 c2_v1 0.875 8 65 | c1_v2 c0_v0 c0_v2 c1_v1 c2_v1 0.75 8 66 | c1_v2 c0_v1 c0_v0 c1_v0 c2_v1 0.625 8 67 | c1_v2 c0_v1 c0_v0 c1_v1 c2_v1 0.5 8 68 | c1_v2 c0_v1 c0_v2 c1_v0 c2_v1 0.625 8 69 | c1_v2 c0_v1 c0_v2 c1_v1 c2_v1 0.0 8 70 | c1_v2 c0_v2 c0_v0 c1_v0 c2_v1 0.625 8 71 | c1_v2 c0_v2 c0_v0 c1_v1 c2_v1 0.5 8 72 | c1_v2 c0_v2 c0_v1 c1_v0 c2_v1 0.75 8 73 | c1_v2 c0_v2 c0_v1 c1_v1 c2_v1 0.5 8 74 | c1_v0 c0_v0 c0_v1 c1_v1 c2_v2 0.75 8 75 | c1_v0 c0_v0 c0_v1 c1_v2 c2_v2 0.625 8 76 | c1_v0 c0_v0 c0_v2 c1_v1 c2_v2 0.75 8 77 | c1_v0 c0_v0 c0_v2 c1_v2 c2_v2 0.25 8 78 | c1_v0 c0_v1 c0_v0 c1_v1 c2_v2 0.75 8 79 | c1_v0 c0_v1 c0_v0 c1_v2 c2_v2 0.25 8 80 | c1_v0 c0_v1 c0_v2 c1_v1 c2_v2 0.75 8 81 | c1_v0 c0_v1 c0_v2 c1_v2 c2_v2 0.625 8 82 | c1_v0 c0_v2 c0_v0 c1_v1 c2_v2 0.5 8 83 | c1_v0 c0_v2 c0_v0 c1_v2 c2_v2 1.0 8 84 | c1_v0 c0_v2 c0_v1 c1_v1 c2_v2 0.625 8 85 | c1_v0 c0_v2 c0_v1 c1_v2 c2_v2 0.875 8 86 | c1_v1 c0_v0 c0_v1 c1_v0 c2_v2 0.375 8 87 | c1_v1 c0_v0 c0_v1 c1_v2 c2_v2 0.75 8 88 | c1_v1 c0_v0 c0_v2 c1_v0 c2_v2 0.75 8 89 | c1_v1 c0_v0 c0_v2 c1_v2 c2_v2 0.625 8 90 | c1_v1 c0_v1 c0_v0 c1_v0 c2_v2 0.625 8 91 | c1_v1 c0_v1 c0_v0 c1_v2 c2_v2 0.125 8 92 | c1_v1 c0_v1 c0_v2 c1_v0 c2_v2 0.75 8 93 | c1_v1 c0_v1 c0_v2 c1_v2 c2_v2 0.375 8 94 | c1_v1 c0_v2 c0_v0 c1_v0 c2_v2 0.5 8 95 | c1_v1 c0_v2 c0_v0 c1_v2 c2_v2 0.75 8 96 | c1_v1 c0_v2 c0_v1 c1_v0 c2_v2 0.625 8 97 | c1_v1 c0_v2 c0_v1 c1_v2 c2_v2 0.625 8 98 | c1_v2 c0_v0 c0_v1 c1_v0 c2_v2 0.25 8 99 | c1_v2 c0_v0 c0_v1 c1_v1 c2_v2 0.375 8 100 | c1_v2 c0_v0 c0_v2 c1_v0 c2_v2 0.625 8 101 | c1_v2 c0_v0 c0_v2 c1_v1 c2_v2 0.625 8 102 | c1_v2 c0_v1 c0_v0 c1_v0 c2_v2 0.75 8 103 | c1_v2 c0_v1 c0_v0 c1_v1 c2_v2 0.625 8 104 | c1_v2 c0_v1 c0_v2 c1_v0 c2_v2 0.875 8 105 | c1_v2 c0_v1 c0_v2 c1_v1 c2_v2 0.5 8 106 | c1_v2 c0_v2 c0_v0 c1_v0 c2_v2 0.5 8 107 | c1_v2 c0_v2 c0_v0 c1_v1 c2_v2 0.625 8 108 | c1_v2 c0_v2 c0_v1 c1_v0 c2_v2 0.5 8 109 | c1_v2 c0_v2 c0_v1 c1_v1 c2_v2 0.625 8 110 | -------------------------------------------------------------------------------- /test/test_analyse.py: -------------------------------------------------------------------------------- 1 | """This test script contains tests for analyze.py""" 2 | 3 | import os 4 | import shutil 5 | 6 | import ABXpy.task 7 | import ABXpy.distances.distances as distances 8 | import ABXpy.distances.metrics.cosine as cosine 9 | import ABXpy.distances.metrics.dtw as dtw 10 | import ABXpy.score as score 11 | import ABXpy.misc.items as items 12 | import ABXpy.analyze as analyze 13 | import numpy as np 14 | 15 | 16 | frozen_folder = os.path.join( 17 | os.path.dirname(os.path.realpath(__file__)), 'frozen_files') 18 | 19 | 20 | def frozen_file(ext): 21 | return os.path.join(frozen_folder, 'data') + '.' 
+ ext 22 | 23 | 24 | def dtw_cosine_distance(x, y, normalized): 25 | return dtw.dtw(x, y, cosine.cosine_distance, normalized) 26 | 27 | 28 | def test_analyze(): 29 | try: 30 | if not os.path.exists('test_items'): 31 | os.makedirs('test_items') 32 | item_file = 'test_items/data.item' 33 | feature_file = 'test_items/data.features' 34 | distance_file = 'test_items/data.distance' 35 | scorefilename = 'test_items/data.score' 36 | taskfilename = 'test_items/data.abx' 37 | analyzefilename = 'test_items/data.csv' 38 | 39 | items.generate_db_and_feat(3, 3, 1, item_file, 2, 3, feature_file) 40 | task = ABXpy.task.Task(item_file, 'c0', 'c1', 'c2') 41 | task.generate_triplets(taskfilename) 42 | distances.compute_distances( 43 | feature_file, '/features/', taskfilename, 44 | distance_file, dtw_cosine_distance, 45 | normalized=True, n_cpu=1) 46 | score.score(taskfilename, distance_file, scorefilename) 47 | analyze.analyze(taskfilename, scorefilename, analyzefilename) 48 | finally: 49 | shutil.rmtree('test_items', ignore_errors=True) 50 | 51 | 52 | def test_threshold_analyze(): 53 | try: 54 | if not os.path.exists('test_items'): 55 | os.makedirs('test_items') 56 | item_file = 'test_items/data.item' 57 | feature_file = 'test_items/data.features' 58 | distance_file = 'test_items/data.distance' 59 | scorefilename = 'test_items/data.score' 60 | taskfilename = 'test_items/data.abx' 61 | analyzefilename = 'test_items/data.csv' 62 | threshold = 2 63 | 64 | items.generate_db_and_feat(3, 3, 1, item_file, 2, 3, feature_file) 65 | task = ABXpy.task.Task(item_file, 'c0', 'c1', 'c2') 66 | task.generate_triplets(taskfilename, threshold=threshold) 67 | distances.compute_distances( 68 | feature_file, '/features/', taskfilename, 69 | distance_file, dtw_cosine_distance, 70 | normalized=True, n_cpu=1) 71 | score.score(taskfilename, distance_file, scorefilename) 72 | analyze.analyze(taskfilename, scorefilename, analyzefilename) 73 | number_triplets = np.loadtxt(analyzefilename, dtype=int, 74 | delimiter='\t', skiprows=1, usecols=[-1]) 75 | assert np.all(number_triplets == threshold) 76 | finally: 77 | shutil.rmtree('test_items', ignore_errors=True) 78 | 79 | 80 | def test_frozen_analyze(): 81 | """Frozen analyze compare the results of a previously "frozen" run with 82 | a new one, asserting that the code did not change in behaviour. 
83 | """ 84 | try: 85 | if not os.path.exists('test_items'): 86 | os.makedirs('test_items') 87 | item_file = frozen_file('item') 88 | feature_file = frozen_file('features') 89 | distance_file = 'test_items/data.distance' 90 | scorefilename = 'test_items/data.score' 91 | taskfilename = 'test_items/data.abx' 92 | analyzefilename = 'test_items/data.csv' 93 | 94 | task = ABXpy.task.Task(item_file, 'c0', 'c1', 'c2') 95 | task.generate_triplets(taskfilename) 96 | distances.compute_distances( 97 | feature_file, '/features/', taskfilename, 98 | distance_file, dtw_cosine_distance, 99 | normalized=True, n_cpu=1) 100 | score.score(taskfilename, distance_file, scorefilename) 101 | analyze.analyze(taskfilename, scorefilename, analyzefilename) 102 | 103 | # assert items.h5cmp(taskfilename, frozen_file('abx')) 104 | # assert items.h5cmp(distance_file, frozen_file('distance')) 105 | # assert items.h5cmp(scorefilename, frozen_file('score')) 106 | assert items.csv_cmp(analyzefilename, frozen_file('csv')) 107 | 108 | finally: 109 | shutil.rmtree('test_items', ignore_errors=True) 110 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | .. image:: https://travis-ci.org/bootphon/ABXpy.svg?branch=master 2 | :target: https://travis-ci.org/bootphon/ABXpy 3 | .. image:: https://codecov.io/gh/bootphon/ABXpy/branch/master/graph/badge.svg 4 | :target: https://codecov.io/gh/bootphon/ABXpy 5 | .. image:: https://anaconda.org/coml/abx/badges/version.svg 6 | :target: https://anaconda.org/coml/abx 7 | .. image:: https://zenodo.org/badge/23788452.svg 8 | :target: https://zenodo.org/badge/latestdoi/23788452 9 | 10 | ABX discrimination test 11 | ======================= 12 | 13 | ABX discrimination is a term that is used for three stimuli presented 14 | on an ABX trial. The third is the focus. The first two stimuli (A 15 | and B) are standard, S1 and S2 in a randomly chosen order, and the 16 | subjects' task is to choose which of the two is matched by the final 17 | stimulus (X). (Glottopedia) 18 | 19 | This package contains the operations necessary to initialize, 20 | calculate and analyse the results of an ABX discrimination task. 21 | 22 | Check out the full documentation at https://docs.cognitive-ml.fr/ABXpy. 23 | 24 | Organisation 25 | ------------ 26 | 27 | It is composed of 3 main modules and other submodules. 28 | 29 | - `task module 30 | `_ is 31 | used for creating a new task and preprocessing. 32 | 33 | - `distances package 34 | `_ is 35 | used for calculating the distances necessary for the score 36 | calculation. 37 | 38 | - `score module 39 | `_ 40 | is used for computing the score of a task. 41 | 42 | - `analyze module 43 | `_ 44 | is used for analysing the results. 45 | 46 | The features can be calculated in numpy via external tools, and made 47 | compatible with this package with the `h5features module 48 | `_, or 49 | directly calculated with one of our tools like `shennong 50 | `_. 
51 | 52 | 53 | The pipeline 54 | ------------ 55 | 56 | +-------------------+----------+-----------------+ 57 | | In | Module | Out | 58 | +===================+==========+=================+ 59 | | - data.item | task | - data.abx | 60 | | - parameters | | | 61 | +-------------------+----------+-----------------+ 62 | | - data.abx | distance | - data.distance | 63 | | - data.features | | | 64 | | - distance | | | 65 | +-------------------+----------+-----------------+ 66 | | - data.abx | score | - data.score | 67 | | - data.distance | | | 68 | +-------------------+----------+-----------------+ 69 | | - data.abx | analyse | - data.csv | 70 | | - data.score | | | 71 | +-------------------+----------+-----------------+ 72 | 73 | See `Files Format 74 | `_ for a 75 | description of the files used as input and output. 76 | 77 | 78 | The task 79 | -------- 80 | 81 | According to what you want to study, it is important to characterise 82 | the ABX triplets. You can characterise your task along 3 axes: on, 83 | across and by a certain label. 84 | 85 | An example of ABX triplet: 86 | 87 | +------+------+------+ 88 | | A | B | X | 89 | +======+======+======+ 90 | | on_1 | on_2 | on_1 | 91 | +------+------+------+ 92 | | ac_1 | ac_1 | ac_2 | 93 | +------+------+------+ 94 | | by | by | by | 95 | +------+------+------+ 96 | 97 | A and X share the same 'on' attribute; A and B share the same 'across' 98 | attribute; A,B and X share the same 'by' attribute. 99 | 100 | Example of use 101 | -------------- 102 | 103 | See ``examples/complete_run.sh`` for a command line run and 104 | ``examples/complete_run.py`` for a Python utilisation. 105 | 106 | 107 | Installation 108 | ------------ 109 | 110 | The recommended installation on linux and macos is using `conda 111 | `_:: 112 | 113 | conda install -c coml abx 114 | 115 | Alternatively you may want to install it from sources. First clone 116 | this repository and go to its root directory. Then :: 117 | 118 | conda env create -n abx -f environment.yml 119 | source activate abx 120 | make install 121 | make test 122 | 123 | 124 | Build the documentation 125 | ----------------------- 126 | 127 | To build the documentation in the folder ``ABXpy/build/doc/html``, 128 | simply have a:: 129 | 130 | make doc 131 | 132 | 133 | Citation 134 | -------- 135 | 136 | If you use this software in your research, please cite: 137 | 138 | `ABX-discriminability measures and applications 139 | `_, 140 | Schatz T., Université Paris 6 (UPMC), 2016. 141 | -------------------------------------------------------------------------------- /doc/FilesFormat.rst: -------------------------------------------------------------------------------- 1 | Files format 2 | ============ 3 | 4 | This package uses several types of files, this section describe them all. 5 | 6 | Dataset 7 | ------- 8 | 9 | Extension: ``.item`` 10 | 11 | This file indexes the database on which the ABX task is executed. It 12 | is a regular text file and should have the following structure: 13 | 14 | ======= ======= ====== ======== ======= ======= 15 | #file onset offset #label 1 label 2 label 3 16 | ======= ======= ====== ======== ======= ======= 17 | file 1 start 1 stop 1 value 1 value 1 value 1 18 | file 2 start 2 stop 2 value 2 value 1 value 1 19 | file 3 start 3 stop 3 value 3 value 1 value 1 20 | ======= ======= ====== ======== ======= ======= 21 | 22 | - the first line must be a header line beginning with the 3 fields 23 | **#file onset offset**. 24 | - **#file** is the name of the file minus the extension. 
Note that the 25 | '#' at the begining is mandatory. 26 | - **onset** is the instant when the sound start. 27 | - **offset** is the instant when the sound end. 28 | - the **label** columns are various regressors relevant to the 29 | discrimination task. Note that the first label column must start 30 | with a **'#'**. 31 | 32 | Features file 33 | ------------- 34 | 35 | Extension: ``.features`` or ``.h5f`` 36 | 37 | This file contains the features and the center time of each window in 38 | the `h5features`_ format. This is a special `hdf5`_ file with the 39 | following attributes: 40 | 41 | - **features** a 2D arrays with the 'feature' dimension along the 42 | columns and the 'time' dimension along the lines. 43 | - **times** a 1D array with the center time of each window. 44 | - **files** the basename of the files from which the features are 45 | extracted. Note that it does not contain the full absolute path nor 46 | the relative path of the files, each file must have a unique name. 47 | 48 | Task file 49 | --------- 50 | 51 | Extension: ``.abx`` 52 | 53 | This file can be generated by the task module. It is a `hdf5`_ 54 | file. It contains all the triplets and the resulting pairs. The 55 | elements are grouped by their 'by' attribute (all the elements with 56 | the same by attributes belong to the same block) 57 | 58 | The structure is as follow: 59 | 60 | data.abx 61 | 62 | - triplets 63 | 64 | - by0: (3 x ?)-array referencing all the possible triplets 65 | sharing a 'by' value of by0 66 | - by1 67 | - etc. 68 | 69 | - unique_pairs (All the pairs AX and BX, useful to calculate the 70 | distances. Note that a pair is designated by a single number due to 71 | a special encoding) 72 | 73 | - by0: 1D-array referencing all the pairs sharing a 'by' value 74 | of by0. Note that this is only 1D instead of 2D due to a 75 | special encoding of the pairs. Let 'n' be the number of 76 | items in the block, 'a' be the index of the first item of 77 | the pair and 'b' the index of the second item: the index of 78 | the pair 'p' = n*a + b 79 | - etc. 80 | 81 | - regressors (infos of the item file in a computer efficient format) 82 | - feat_dbs (infos of the item file in a computer efficient format) 83 | 84 | Distance file 85 | ------------- 86 | 87 | Extension: ``.distance`` 88 | 89 | This file contains the distances between the two members of each 90 | unique pair. The distances are store by 'by' block and in the same 91 | order as the unique_pairs in the `Task file`_. 92 | 93 | - distances 94 | 95 | - by0: 1D-array containing the distances between the two members 96 | of each pair. 97 | - by1 98 | - etc. 99 | 100 | Score file 101 | ---------- 102 | 103 | Extension: ``.score`` 104 | 105 | This file contains the score of each triplets. The score is 1 when X 106 | is closer to A and -1 when X is closer to B. The score are stored by 107 | 'by' block and in the same order as the triplets in the `Task file`_. 108 | 109 | - scores 110 | 111 | - by0: 1D-array of integers containing the score of each triplet. 112 | - by1 113 | - etc. 114 | 115 | Analyse file 116 | ------------ 117 | 118 | Extension: ``.csv`` 119 | 120 | The output file of the ABX baseline, in a human readable 121 | format. Contains the average results collapsed over triplets sharing 122 | the same on, across and by attributes. It uses a score of 1 when X is 123 | closer to A and 0 when X is closer to B. 
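The file can be loaded with any CSV reader; for instance, a minimal sketch
with pandas (assuming the tab-separated layout written by the analyze module,
with the 'score' and 'n' columns shown in the example below):

.. code-block:: python

    import pandas

    df = pandas.read_csv('data.csv', sep='\t')
    # overall ABX score, weighting each collapsed cell by its number of
    # triplets 'n'
    average = (df['score'] * df['n']).sum() / df['n'].sum()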
124 | 125 | The extensions _1 and _2 to the labels name follow the following 126 | convention: 127 | 128 | +------+------+------+ 129 | | A | B | X | 130 | +======+======+======+ 131 | | on_1 | on_2 | on_1 | 132 | +------+------+------+ 133 | | ac_1 | ac_1 | ac_2 | 134 | +------+------+------+ 135 | 136 | Example for a task on 'on', across 'ac' and by 'by': 137 | 138 | +------+------+------+------+----+-------+---+ 139 | | on_1 | ac_1 | ac_2 | on_2 | by | score | n | 140 | +======+======+======+======+====+=======+===+ 141 | | v0 | v0 | v1 | v1 | v0 | 0.2 | 5 | 142 | +------+------+------+------+----+-------+---+ 143 | | v1 | v1 | v0 | v0 | v0 | 0.7 | 3 | 144 | +------+------+------+------+----+-------+---+ 145 | 146 | - **on_1** value of 'on' label for A and X 147 | - **on_2** value of 'on' label for B 148 | - **ac_1** value of 'ac' label for A and B 149 | - **ac_2** value of 'ac' label for X 150 | - **by** value of 'by' label for A, B and X 151 | - **score** average score for those triplets 152 | - **n** number of triplets 153 | 154 | .. _hdf5: http://www.hdfgroup.org/HDF5/ 155 | .. _h5features: https://h5features.readthedocs.io 156 | -------------------------------------------------------------------------------- /ABXpy/sideop/filter_manager.py: -------------------------------------------------------------------------------- 1 | # make sure the rest of the ABXpy package is accessible 2 | 3 | import ABXpy.sideop.side_operations_manager as side_operations_manager 4 | import ABXpy.dbfun.dbfun_compute as dbfun_compute 5 | import ABXpy.dbfun.dbfun_lookuptable as dbfun_lookuptable 6 | import ABXpy.dbfun.dbfun_column as dbfun_column 7 | 8 | import numpy as np 9 | 10 | 11 | class FilterManager(side_operations_manager.SideOperationsManager): 12 | 13 | """Manage the filters on attributes (on, across, by) or elements (A, B, X) 14 | for further processing""" 15 | 16 | def __init__(self, db_hierarchy, on, across, by, filters): 17 | side_operations_manager.SideOperationsManager.__init__( 18 | self, db_hierarchy, on, across, by) 19 | # this case is specific to filters, it applies a generic filter to the 20 | # database before considering A, B and X stuff. 21 | self.generic = [] 22 | 23 | # associate each of the provided filters to the appropriate point in 24 | # the computation flow 25 | # filt can be: the name of a column of the database (possibly 26 | # extended), the name of lookup file, the name of a script, a script 27 | # under the form of a string (that doesnt end by .dbfun...) 28 | for filt in filters: 29 | # instantiate appropriate dbfun 30 | if filt in self.extended_cols: # column already in db 31 | db_fun = dbfun_column.DBfun_Column(filt, indexed=False) 32 | # evaluate context is wasteful in this case... 
not even 33 | # necessary to have a dbfun at all 34 | elif len(filt) >= 6 and filt[-6:] == '.dbfun': # lookup table 35 | # ask for re-interpreted indexed outputs 36 | db_fun = dbfun_lookuptable.DBfun_LookupTable( 37 | filt, indexed=False) 38 | else: # on the fly computation 39 | db_fun = dbfun_compute.DBfun_Compute(filt, self.extended_cols) 40 | self.add(db_fun) 41 | 42 | def classify_generic(self, elements, db_fun, db_variables): 43 | # check if there are only non-extended names and, only if this is the 44 | # case, instantiate 'generic' field of db_variables 45 | if {s for r, s in elements} == set(['']): 46 | db_variables['generic'] = set(elements) 47 | self.generic.append(db_fun) 48 | self.generic_context['generic'].update(db_variables['generic']) 49 | elements = {} 50 | return elements, db_variables 51 | 52 | def by_filter(self, by_values): 53 | return singleton_filter(self.evaluate_by(by_values)) 54 | 55 | def generic_filter(self, by_values, db): 56 | return db.iloc[vectorial_filter(lambda context: self.evaluate_generic(by_values, db, context), np.arange(len(db)))] 57 | 58 | def on_across_by_filter(self, on_across_by_values): 59 | return singleton_filter(self.evaluate_on_across_by(on_across_by_values)) 60 | 61 | def A_filter(self, on_across_by_values, db, indices): 62 | # Caution: indices contains db-related indices 63 | # but the returned result contains indices with respect to indices 64 | indices_ind = np.arange(len(indices)) 65 | return vectorial_filter(lambda context: self.evaluate_A(on_across_by_values, db, indices, context), indices_ind) 66 | 67 | def B_filter(self, on_across_by_values, db, indices): 68 | # Caution: indices contains db-related indices 69 | # but the returned result contains indices with respect to indices 70 | indices_ind = np.arange(len(indices)) 71 | return vectorial_filter(lambda context: self.evaluate_B(on_across_by_values, db, indices, context), indices_ind) 72 | 73 | def X_filter(self, on_across_by_values, db, indices): 74 | # Caution: indices contains db-related indices 75 | # but the returned result contains indices with respect to indices 76 | indices_ind = np.arange(len(indices)) 77 | return vectorial_filter(lambda context: self.evaluate_X(on_across_by_values, db, indices, context), indices_ind) 78 | 79 | def ABX_filter(self, on_across_by_values, db, triplets): 80 | # triplets contains db-related indices 81 | # the returned result contains indices with respect to triplets 82 | indices = np.arange(len(triplets)) 83 | return vectorial_filter(lambda context: self.evaluate_ABX(on_across_by_values, db, triplets, context), indices) 84 | 85 | 86 | def singleton_filter(generator): 87 | keep = True 88 | for result in generator: 89 | if not(result): 90 | keep = False 91 | break 92 | return keep 93 | 94 | 95 | def vectorial_filter(generator, indices): 96 | """ 97 | 98 | .. note:: To allow a lazy evaluation of the filter, the context is filtered 99 | explicitly which acts on the generator by a side-effect (dict being 100 | mutable in python) 101 | """ 102 | kept = np.array(indices) 103 | context = {} 104 | for result in generator(context): 105 | still_up = np.where(result)[0] 106 | kept = kept[still_up] 107 | for var in context: 108 | # keep testing only the case that are still possibly True 109 | context[var] = [context[var][e] for e in still_up] 110 | # FIXME wouldn't using only numpy arrays be more performant ? 
111 | if not(kept.size): 112 | break 113 | return kept 114 | -------------------------------------------------------------------------------- /ABXpy/distance.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import editdistance 4 | import os 5 | import sys 6 | import warnings 7 | 8 | from ABXpy.distances import distances 9 | import ABXpy.distances.metrics.dtw as dtw 10 | import ABXpy.distances.metrics.kullback_leibler as kl 11 | import ABXpy.distances.metrics.cosine as cosine 12 | 13 | def default_distance(x, y, normalized): 14 | """ Dynamic time warping cosine distance 15 | 16 | The "feature" dimension is along the columns and the "time" dimension 17 | along the lines of arrays x and y 18 | """ 19 | if x.shape[0] > 0 and y.shape[0] > 0: 20 | # x and y are not empty 21 | d = dtw.dtw(x, y, cosine.cosine_distance, 22 | normalized=normalized) 23 | elif x.shape[0] == y.shape[0]: 24 | # both x and y are empty 25 | d = 0 26 | else: 27 | # x or y is empty 28 | d = np.inf 29 | return d 30 | 31 | def dtw_kl_distance(x, y, normalized=True): 32 | """ Dynamic time warping cosine distance 33 | 34 | The "feature" dimension is along the columns and the "time" dimension 35 | along the lines of arrays x and y 36 | """ 37 | if x.shape[0] > 0 and y.shape[0] > 0: 38 | # x and y are not empty 39 | d = dtw.dtw(x, y, kl.kl_divergence, 40 | normalized=normalized) 41 | elif x.shape[0] == y.shape[0]: 42 | # both x and y are empty 43 | d = 0 44 | else: 45 | # x or y is empty 46 | d = np.inf 47 | return d 48 | 49 | def edit_distance(x, y): 50 | """Levenshtein Distance 51 | 52 | The "feature" dimension is along the columns and the "time" dimension 53 | along the lines of arrays x and y 54 | """ 55 | # convert arrays to tuple, to evaluate w/ editdistance 56 | def totuple(a): 57 | try: 58 | return tuple(totuple(i) for i in a) 59 | except TypeError: 60 | return a 61 | 62 | if x.shape[0] > 0 and y.shape[0] > 0: 63 | # x and y are not empty 64 | d = editdistance.eval(totuple(x), totuple(y)) 65 | elif x.shape[0] == y.shape[0]: 66 | # both x and y are empty 67 | d = 0 68 | else: 69 | # x or y is empty 70 | d = np.inf 71 | return d 72 | 73 | def run(features, task, output, normalized, 74 | distance=None, njobs=1, group='features'): 75 | njobs = int(njobs) 76 | if distance: 77 | if distance=="levenshtein": 78 | distancefun = edit_distance 79 | elif distance=="dtw_kl": 80 | distancefun = dtw_kl_distance 81 | else: 82 | distancepair = distance.split('.') 83 | distancemodule = distancepair[0] 84 | distancefunction = distancepair[1] 85 | path, mod = os.path.split(distancemodule) 86 | sys.path.insert(0, path) 87 | distancefun = getattr(__import__(mod), distancefunction) 88 | else: 89 | distancefun = default_distance 90 | 91 | distances.compute_distances( 92 | features, group, task, output, 93 | distancefun, normalized=normalized, n_cpu=njobs) 94 | 95 | 96 | def main(): 97 | parser = argparse.ArgumentParser( 98 | description='Compute distances for the ABX discrimination task') 99 | 100 | parser.add_argument( 101 | 'features', 102 | help='h5features file containing the feature to evaluate') 103 | 104 | parser.add_argument( 105 | '-g', '--group', default='features', 106 | help='group to read in the h5features file, default is %(default)s') 107 | 108 | parser.add_argument( 109 | 'task', help='task file') 110 | 111 | parser.add_argument( 112 | 'output', help='output file for distance pairs') 113 | 114 | parser.add_argument( 115 | '-d', '--distance', 
metavar='distancemodule.distancefunction', 116 | help='''Define distance module to use.\n\n''' 117 | '''Use "-d levenshtein" to use the Levenshtein distance''' 118 | ''' instead of DTW.\n''' 119 | '''Use "-d dtw_kl" to use the Kullback Leibler divergence''' 120 | ''' with the DTW, instead of the default cosine distance.\n\n''' 121 | '''Use -d distancemodule.distancefunction to use you own''' 122 | ''' distance\n\n''' 123 | '''If not set, it defaults to dtw cosine distance''') 124 | 125 | parser.add_argument( 126 | '-j', '--njobs', type=int, default=1, 127 | help='number of cpus to use') 128 | 129 | parser.add_argument( 130 | '-n', '--normalization', type=int, default=None, 131 | help='if dtw distance selected, compute with normalization or with ' 132 | 'sum. If put to 1 : computes with normalization, if put to 0 : ' 133 | 'computes with sum. Common choice is to use normalization (-n 1)') 134 | 135 | args = parser.parse_args() 136 | 137 | if os.path.exists(args.output): 138 | warnings.warn("Overwriting distance file " + args.output, UserWarning) 139 | os.remove(args.output) 140 | 141 | # if dtw distance selected, fore use of normalization parameter : 142 | if (args.distance is None and args.normalization is None): 143 | sys.exit("ERROR : DTW normalization parameter not specified !") 144 | 145 | run(args.features, args.task, args.output, normalized=args.normalization, 146 | distance=args.distance, njobs=args.njobs, group=args.group) 147 | 148 | 149 | if __name__ == '__main__': 150 | main() 151 | -------------------------------------------------------------------------------- /doc/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | REM Command file for Sphinx documentation 4 | 5 | if "%SPHINXBUILD%" == "" ( 6 | set SPHINXBUILD=sphinx-build 7 | ) 8 | set BUILDDIR=_build 9 | set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . 10 | set I18NSPHINXOPTS=%SPHINXOPTS% . 11 | if NOT "%PAPER%" == "" ( 12 | set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% 13 | set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% 14 | ) 15 | 16 | if "%1" == "" goto help 17 | 18 | if "%1" == "help" ( 19 | :help 20 | echo.Please use `make ^` where ^ is one of 21 | echo. html to make standalone HTML files 22 | echo. dirhtml to make HTML files named index.html in directories 23 | echo. singlehtml to make a single large HTML file 24 | echo. pickle to make pickle files 25 | echo. json to make JSON files 26 | echo. htmlhelp to make HTML files and a HTML help project 27 | echo. qthelp to make HTML files and a qthelp project 28 | echo. devhelp to make HTML files and a Devhelp project 29 | echo. epub to make an epub 30 | echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter 31 | echo. text to make text files 32 | echo. man to make manual pages 33 | echo. texinfo to make Texinfo files 34 | echo. gettext to make PO message catalogs 35 | echo. changes to make an overview over all changed/added/deprecated items 36 | echo. linkcheck to check all external links for integrity 37 | echo. doctest to run all doctests embedded in the documentation if enabled 38 | goto end 39 | ) 40 | 41 | if "%1" == "clean" ( 42 | for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i 43 | del /q /s %BUILDDIR%\* 44 | goto end 45 | ) 46 | 47 | if "%1" == "html" ( 48 | %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html 49 | if errorlevel 1 exit /b 1 50 | echo. 51 | echo.Build finished. The HTML pages are in %BUILDDIR%/html. 
52 | goto end 53 | ) 54 | 55 | if "%1" == "dirhtml" ( 56 | %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml 57 | if errorlevel 1 exit /b 1 58 | echo. 59 | echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. 60 | goto end 61 | ) 62 | 63 | if "%1" == "singlehtml" ( 64 | %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml 65 | if errorlevel 1 exit /b 1 66 | echo. 67 | echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. 68 | goto end 69 | ) 70 | 71 | if "%1" == "pickle" ( 72 | %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle 73 | if errorlevel 1 exit /b 1 74 | echo. 75 | echo.Build finished; now you can process the pickle files. 76 | goto end 77 | ) 78 | 79 | if "%1" == "json" ( 80 | %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json 81 | if errorlevel 1 exit /b 1 82 | echo. 83 | echo.Build finished; now you can process the JSON files. 84 | goto end 85 | ) 86 | 87 | if "%1" == "htmlhelp" ( 88 | %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp 89 | if errorlevel 1 exit /b 1 90 | echo. 91 | echo.Build finished; now you can run HTML Help Workshop with the ^ 92 | .hhp project file in %BUILDDIR%/htmlhelp. 93 | goto end 94 | ) 95 | 96 | if "%1" == "qthelp" ( 97 | %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp 98 | if errorlevel 1 exit /b 1 99 | echo. 100 | echo.Build finished; now you can run "qcollectiongenerator" with the ^ 101 | .qhcp project file in %BUILDDIR%/qthelp, like this: 102 | echo.^> qcollectiongenerator %BUILDDIR%\qthelp\ABXpy.qhcp 103 | echo.To view the help file: 104 | echo.^> assistant -collectionFile %BUILDDIR%\qthelp\ABXpy.ghc 105 | goto end 106 | ) 107 | 108 | if "%1" == "devhelp" ( 109 | %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp 110 | if errorlevel 1 exit /b 1 111 | echo. 112 | echo.Build finished. 113 | goto end 114 | ) 115 | 116 | if "%1" == "epub" ( 117 | %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub 118 | if errorlevel 1 exit /b 1 119 | echo. 120 | echo.Build finished. The epub file is in %BUILDDIR%/epub. 121 | goto end 122 | ) 123 | 124 | if "%1" == "latex" ( 125 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 126 | if errorlevel 1 exit /b 1 127 | echo. 128 | echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. 129 | goto end 130 | ) 131 | 132 | if "%1" == "text" ( 133 | %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text 134 | if errorlevel 1 exit /b 1 135 | echo. 136 | echo.Build finished. The text files are in %BUILDDIR%/text. 137 | goto end 138 | ) 139 | 140 | if "%1" == "man" ( 141 | %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man 142 | if errorlevel 1 exit /b 1 143 | echo. 144 | echo.Build finished. The manual pages are in %BUILDDIR%/man. 145 | goto end 146 | ) 147 | 148 | if "%1" == "texinfo" ( 149 | %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo 150 | if errorlevel 1 exit /b 1 151 | echo. 152 | echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. 153 | goto end 154 | ) 155 | 156 | if "%1" == "gettext" ( 157 | %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale 158 | if errorlevel 1 exit /b 1 159 | echo. 160 | echo.Build finished. The message catalogs are in %BUILDDIR%/locale. 161 | goto end 162 | ) 163 | 164 | if "%1" == "changes" ( 165 | %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes 166 | if errorlevel 1 exit /b 1 167 | echo. 168 | echo.The overview file is in %BUILDDIR%/changes. 
169 | goto end 170 | ) 171 | 172 | if "%1" == "linkcheck" ( 173 | %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck 174 | if errorlevel 1 exit /b 1 175 | echo. 176 | echo.Link check complete; look for any errors in the above output ^ 177 | or in %BUILDDIR%/linkcheck/output.txt. 178 | goto end 179 | ) 180 | 181 | if "%1" == "doctest" ( 182 | %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest 183 | if errorlevel 1 exit /b 1 184 | echo. 185 | echo.Testing of doctests in the sources finished, look at the ^ 186 | results in %BUILDDIR%/doctest/output.txt. 187 | goto end 188 | ) 189 | 190 | :end 191 | -------------------------------------------------------------------------------- /ABXpy/dbfun/lookuptable_connector.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | AST visitor class for finding calls to lookup table (.dbfun) files in 4 | a script and make the script ready for execution. 5 | 6 | There is a restriction: when in a lookup table calling node, check 7 | that the node.code_nodes concatened of the children is empty: 8 | i.e. hierarchical calls to auxiliary h5 file are not allowed with a 9 | depth larger than 1. We could allow deeper calls, but would there be 10 | any practical benefit ? Probably not because the input to a lookup 11 | table call can only be of the type of a column of the database (at 12 | least with the current system). 13 | 14 | """ 15 | 16 | import ast 17 | import uuid 18 | 19 | 20 | class LookupTableConnector(ast.NodeTransformer): 21 | def __init__(self, script, aux_functions, aliases, *args, **kwargs): 22 | ast.NodeTransformer.__init__(self, *args, **kwargs) 23 | self.aux_functions = aux_functions 24 | self.aliases = aliases 25 | self.script = script 26 | 27 | def check_Call(self, node): 28 | if not(isinstance(node.func.ctx, ast.Load) and not(node.keywords) and 29 | node.kwargs is None and node.starargs is None): 30 | raise ValueError( 31 | 'Call to function %s defined in an auxiliary h5 file is not ' 32 | 'correct in script: %s' % (node.func.id, self.script)) 33 | 34 | def check_Subscript(self, node, func_name): 35 | error = ValueError( 36 | 'Call to function %s defined in an auxiliary h5 file is not ' 37 | 'correct in script: %s' % (func_name, self.script)) 38 | 39 | if not(isinstance(node.slice, ast.Index) and 40 | isinstance(node.ctx, ast.Load)): 41 | raise error 42 | 43 | if not(isinstance(node.slice.value, ast.Tuple) and 44 | isinstance(node.slice.value.ctx, ast.Load)): 45 | raise error 46 | 47 | # output column must appear explicitly as strings 48 | if not(all([isinstance(e, ast.Str) for e in node.slice.value.elts])): 49 | raise error 50 | 51 | def check_flatness(self, node): 52 | if node.code_nodes: 53 | raise ValueError( 54 | 'There are calls to functions defined in auxiliary files whose' 55 | ' arguments are defined in terms of calls to other functions ' 56 | 'defined in auxiliary files in script: %s' % self.script) 57 | 58 | # extract and store in a code_node all useful information about the call 59 | # and replace node in its parent by a ast.Name node pointing to a random 60 | # but duly stored 'varname' 61 | def bundle_call_info(self, node, call_node): 62 | code_node = {} 63 | if not(node is call_node): 64 | code_node['output_cols'] = [e.s for e in node.slice.value.elts] 65 | code_node['child_asts'] = [ast.Expression(arg) for arg in node.args] 66 | code_node['function'] = self.aux_functions( 67 | self.aliases.index(call_node.func.id)) 68 | code_node['varname'] = 'var_' + 
str(uuid.uuid4()).translate(None, '-') 69 | new_node = ast.copy_location( 70 | ast.Name(ctx=ast.Load(), id=code_node['varname']), node) 71 | new_node.code_nodes = [code_node] 72 | return new_node 73 | 74 | # case of call with output columns specified: aux_file_alias(in_1, ..., 75 | # in_k).out['out_name_1',..., 'out_name_n'] 76 | def visit_Subscript(self, node): 77 | # check if it is a call to an auxiliary h5 file 78 | aux_call = False 79 | if isinstance(node.value, ast.Attribute) and node.value.attr == 'out': 80 | call_node = node.value.value 81 | if ((isinstance(call_node, ast.Call) and call_node.func.id 82 | in self.aliases)): 83 | aux_call = True 84 | if not(aux_call): 85 | # visit all children and bubble up code_nodes 86 | node = self.generic_visit(node) 87 | return node 88 | else: 89 | # first check correctness of the call 90 | self.check_Call(call_node) 91 | self.check_Subscript(node, call_node.func.id) 92 | # check that the aux function call is flat (i.e. there are no other 93 | # aux function call in the call_node children) 94 | # generic visit here, not visit_Call, to bubble up code_nodes 95 | call_node = self.generic_visit(call_node) 96 | node.code_nodes = call_node.code_nodes 97 | self.check_flatness(node) 98 | # extract and store in a code_node all useful information about 99 | # the call 100 | # and replace node in its parent by a ast.Name node pointing to a 101 | # random but duly stored 'varname' 102 | return self.bundle_call_info(node, call_node) 103 | 104 | # case of call for all outputs: aux_file_alias(in_1, ..., in_k) 105 | def visit_Call(self, node): 106 | # check if it is not a call to an auxiliary h5 file 107 | if not(node.func.id in self.aliases): 108 | # visit all children and bubble up code_nodes 109 | node = self.generic_visit(node) 110 | return node 111 | else: 112 | # first check correctness of the call 113 | self.check_Call(node) 114 | # check that the aux function call is flat (i.e. there are no other 115 | # aux function call in the node children) 116 | node = self.generic_visit(node) # bubble up code nodes 117 | self.check_flatness(node) 118 | return self.bundle_call_info(node, node) 119 | 120 | def generic_visit(self, node): 121 | # visit all children 122 | super(ast.NodeVisitor, self).generic_visit(node) 123 | # bubble up code_nodes from the children if any (cleaning up children 124 | # at the same time to keep sound asts) 125 | node.code_nodes = [] 126 | for child in ast.iter_child_nodes(node): 127 | node.code_nodes = node.code_nodes + child.code_nodes 128 | del(child.code_nodes) 129 | -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | 10 | # Internal variables. 11 | PAPEROPT_a4 = -D latex_paper_size=a4 12 | PAPEROPT_letter = -D latex_paper_size=letter 13 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 14 | # the i18n builder cannot share the environment and doctrees with the others 15 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 
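# The variables above can be overridden from the command line; for example
# (illustrative invocations, not targets defined in this Makefile):
#   make html SPHINXOPTS="-W"        # treat Sphinx warnings as errors
#   make latex PAPER=a4              # pick up the a4 paper option defined above
#   make html BUILDDIR=/tmp/abx-doc  # build into an alternative directory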
16 | 17 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext 18 | 19 | help: 20 | @echo "Please use \`make ' where is one of" 21 | @echo " html to make standalone HTML files" 22 | @echo " dirhtml to make HTML files named index.html in directories" 23 | @echo " singlehtml to make a single large HTML file" 24 | @echo " pickle to make pickle files" 25 | @echo " json to make JSON files" 26 | @echo " htmlhelp to make HTML files and a HTML help project" 27 | @echo " qthelp to make HTML files and a qthelp project" 28 | @echo " devhelp to make HTML files and a Devhelp project" 29 | @echo " epub to make an epub" 30 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 31 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 32 | @echo " text to make text files" 33 | @echo " man to make manual pages" 34 | @echo " texinfo to make Texinfo files" 35 | @echo " info to make Texinfo files and run them through makeinfo" 36 | @echo " gettext to make PO message catalogs" 37 | @echo " changes to make an overview of all changed/added/deprecated items" 38 | @echo " linkcheck to check all external links for integrity" 39 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 40 | 41 | clean: 42 | -rm -rf $(BUILDDIR)/* 43 | 44 | html: 45 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 46 | @echo 47 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 48 | 49 | dirhtml: 50 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 51 | @echo 52 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 53 | 54 | singlehtml: 55 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 56 | @echo 57 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 58 | 59 | pickle: 60 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 61 | @echo 62 | @echo "Build finished; now you can process the pickle files." 63 | 64 | json: 65 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 66 | @echo 67 | @echo "Build finished; now you can process the JSON files." 68 | 69 | htmlhelp: 70 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 71 | @echo 72 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 73 | ".hhp project file in $(BUILDDIR)/htmlhelp." 74 | 75 | qthelp: 76 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 77 | @echo 78 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 79 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 80 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/ABXpy.qhcp" 81 | @echo "To view the help file:" 82 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/ABXpy.qhc" 83 | 84 | devhelp: 85 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 86 | @echo 87 | @echo "Build finished." 88 | @echo "To view the help file:" 89 | @echo "# mkdir -p $$HOME/.local/share/devhelp/ABXpy" 90 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/ABXpy" 91 | @echo "# devhelp" 92 | 93 | epub: 94 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 95 | @echo 96 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 97 | 98 | latex: 99 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 100 | @echo 101 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 
102 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 103 | "(use \`make latexpdf' here to do that automatically)." 104 | 105 | latexpdf: 106 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 107 | @echo "Running LaTeX files through pdflatex..." 108 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 109 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 110 | 111 | text: 112 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 113 | @echo 114 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 115 | 116 | man: 117 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 118 | @echo 119 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 120 | 121 | texinfo: 122 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 123 | @echo 124 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 125 | @echo "Run \`make' in that directory to run these through makeinfo" \ 126 | "(use \`make info' here to do that automatically)." 127 | 128 | info: 129 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 130 | @echo "Running Texinfo files through makeinfo..." 131 | make -C $(BUILDDIR)/texinfo info 132 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 133 | 134 | gettext: 135 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 136 | @echo 137 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 138 | 139 | changes: 140 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 141 | @echo 142 | @echo "The overview file is in $(BUILDDIR)/changes." 143 | 144 | linkcheck: 145 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 146 | @echo 147 | @echo "Link check complete; look for any errors in the above output " \ 148 | "or in $(BUILDDIR)/linkcheck/output.txt." 149 | 150 | doctest: 151 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 152 | @echo "Testing of doctests in the sources finished, look at the " \ 153 | "results in $(BUILDDIR)/doctest/output.txt." 
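# A typical documentation workflow with the targets above might be (assumed,
# not enforced by this Makefile): "make clean html" to rebuild the HTML pages
# from scratch, then "make linkcheck doctest" to validate external links and
# embedded doctests before publishing.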
154 | -------------------------------------------------------------------------------- /bin/ABXrun.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Command-line program to run the complete ABX pipeline""" 3 | 4 | import argparse 5 | import json 6 | import os 7 | 8 | import ABXpy.task 9 | import ABXpy.distances.distances as distances 10 | import ABXpy.distances.metrics.cosine as cosine 11 | import ABXpy.distances.metrics.dtw as dtw 12 | import ABXpy.score as score 13 | import ABXpy.analyze as analyze 14 | 15 | 16 | def get_name(filename): 17 | return os.path.splitext(os.path.basename(filename))[0] 18 | 19 | 20 | def get_arg(key, args): 21 | if key in args: 22 | return args[key] 23 | else: 24 | return None 25 | 26 | 27 | def dtw_cosine_distance(x, y): 28 | return dtw.dtw(x, y, cosine.cosine_distance) 29 | 30 | 31 | def test_analyze(itemfile, featurefile, args, taskfile=None, distance=None, 32 | distancefile=None, scorefile=None, analyzefile=None, 33 | filename=None): 34 | 35 | on = get_arg('on', args) 36 | assert on, ("The 'on' argument was not found, this argument is mandatory" 37 | "for the task") 38 | 39 | across = get_arg('across', args) 40 | by = get_arg('by', args) 41 | filters = get_arg('filters', args) 42 | reg = get_arg('reg', args) 43 | 44 | if not filename: 45 | filename = '_'.join( 46 | filter(None, [get_name(itemfile), 47 | get_name(featurefile), 48 | str(on), 49 | str(across), 50 | str(by)])) 51 | 52 | if not distancefile: 53 | distancefile = filename + '.distance' 54 | 55 | if not scorefile: 56 | scorefile = filename + '.score' 57 | 58 | if not analyzefile: 59 | analyzefile = filename + '.csv' 60 | 61 | task = ABXpy.task.Task(itemfile, on, across, by, filters, reg, 62 | features=featurefile) 63 | task.generate_triplets() 64 | 65 | if not distance: 66 | distance = dtw_cosine_distance 67 | distances.compute_distances(featurefile, '/features/', taskfile, 68 | distancefile, distance) 69 | 70 | score.score(taskfile, distancefile, scorefile) 71 | 72 | analyze.analyze(scorefile, taskfile, analyzefile) 73 | 74 | 75 | def parse_args(): 76 | parser = argparse.ArgumentParser( 77 | prog='ABXrun.py', 78 | # formatter_class=argparse.RawDescriptionHelpFormatter, 79 | description='Run the complete ABX pipeline.', 80 | epilog='Example usage: ./ABXrun.py data.item data.feature data.conf') 81 | 82 | parser.add_argument( 83 | 'itemfile', metavar='itemfile', 84 | help='Input item file in item format, e.g. data.item') 85 | 86 | parser.add_argument( 87 | 'featurefile', metavar='featurefile', 88 | help='Input feature file in h5features format, e.g. data.features') 89 | 90 | parser.add_argument( 91 | 'config', metavar='configfile', 92 | help='Input config file in json format, e.g. data.conf.\n' 93 | 'This file should at least contain the task parameters. ' 94 | 'You can also include the filenames you want to use and the saga' 95 | ' parameters.') 96 | 97 | parser.add_argument( 98 | '--taskfile', metavar='taskfile', 99 | required=False, 100 | help='Output task file where the task information will be stored, e.g.' 101 | ' data.abx.') 102 | 103 | parser.add_argument( 104 | '--distancefile', metavar='distancefile', 105 | required=False, 106 | help='Output distance file where the distances between pairs will be ' 107 | 'stored, e.g. data.distance') 108 | 109 | parser.add_argument( 110 | '--analyzefile', metavar='analyzefile', 111 | required=False, 112 | help='Output analyze file where the collapsed results will be stored, ' 113 | 'e.g. 
data.csv') 114 | 115 | parser.add_argument( 116 | '--scorefile', metavar='scorefile', 117 | required=False, 118 | help='Output score file where the score of the triplets will be stored' 119 | ', e.g. data.score') 120 | 121 | parser.add_argument( 122 | '--distance', metavar='distance', 123 | required=False, 124 | help='Callable distance function to be used for distance calculation,' 125 | ' by default the dynamic time warping cosine distance will be ' 126 | 'used') 127 | 128 | parser.add_argument( 129 | '--name', metavar='filename', 130 | required=False, 131 | help='If you specify a filename, all the files generated will have ' 132 | 'the same basename and a standard extension. For instance, the ' 133 | 'task file will be named filename.abx') 134 | 135 | return vars(parser.parse_args()) 136 | 137 | 138 | def main(): 139 | args = parse_args() 140 | 141 | # mandatory args 142 | itemfile = args['itemfile'] 143 | assert os.path.exists(itemfile) 144 | 145 | featurefile = args['featurefile'] 146 | assert os.path.exists(featurefile) 147 | 148 | configfile = args['config']  # the positional argument is stored under 'config' 149 | assert os.path.exists(configfile) 150 | 151 | try: 152 | config = json.load(open(configfile, 'r')) 153 | except IOError: 154 | print('No such file: {}'.format(configfile)) 155 | exit() 156 | assert get_arg('on', config) 157 | 158 | # optional args 159 | taskfile = get_arg('taskfile', args) 160 | if not taskfile: 161 | taskfile = get_arg('taskfile', config) 162 | 163 | distancefile = get_arg('distancefile', args) 164 | if not distancefile: 165 | distancefile = get_arg('distancefile', config) 166 | 167 | scorefile = get_arg('scorefile', args) 168 | if not scorefile: 169 | scorefile = get_arg('scorefile', config) 170 | 171 | analyzefile = get_arg('analyzefile', args) 172 | if not analyzefile: 173 | analyzefile = get_arg('analyzefile', config) 174 | 175 | distance = get_arg('distance', args) 176 | if not distance: 177 | distance = get_arg('distance', config) 178 | 179 | filename = get_arg('name', args)  # the '--name' option is stored under 'name' 180 | if not filename: 181 | filename = get_arg('filename', config) 182 | 183 | test_analyze(itemfile, featurefile, config, taskfile, distance, 184 | distancefile, scorefile, analyzefile, filename) 185 | 186 | 187 | if __name__ == '__main__': 188 | main() 189 | -------------------------------------------------------------------------------- /ABXpy/score.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """This module is used for computing the score of a task (see `task Module`_ on 3 | how to create a task) 4 | 5 | This module contains the actual computation of the score. It requires a task 6 | and a distance, and redirects the output to a score file. 7 | 8 | The main function takes a distance file and a task file as input to compute 9 | the score of the task on those distances. X closer to A is associated with 10 | a score of 1 and X closer to B with a score of -1. 11 | 12 | The distances between pairs in the distance file must be ordered the same 13 | way as the pairs in the task file, and the triplet scores in the output 14 | file will be ordered the same way as the triplets in the task file. 15 | 16 | Usage 17 | ----- 18 | From the command line: 19 | 20 | .. code-block:: bash 21 | 22 | python score.py data.abx data.distance data.score 23 | 24 | In python: 25 | 26 | ..
code-block:: python 27 | 28 | import ABXpy.task 29 | import ABXpy.score 30 | # create a new task: 31 | myTask = ABXpy.task.Task('data.item', 'on_feature', 'across_feature', \ 32 | 'by_feature', filters=my_filters, regressors=my_regressors) 33 | myTask.generate_triplets() 34 | #initialise distance 35 | #TODO shouldn't this be available from score 36 | # calculate the scores: 37 | ABXpy.score('data.abx', 'myDistance.???', 'data.score') 38 | 39 | """ 40 | 41 | import argparse 42 | import h5py 43 | import numpy as np 44 | import os 45 | 46 | import ABXpy.h5tools.h52np as h52np 47 | import ABXpy.misc.type_fitting as type_fitting 48 | 49 | 50 | # FIXME: include distance computation here 51 | def score(task_file, distance_file, score_file=None, score_group='scores'): 52 | """Calculate the score of a task and put the results in a hdf5 file. 53 | 54 | Parameters 55 | ---------- 56 | task_file : string 57 | The hdf5 file containing the task (with the triplets and pairs 58 | generated) 59 | distance_file : string 60 | The hdf5 file containing the distances between the pairs 61 | score_file : string, optional 62 | The hdf5 file that will contain the results 63 | """ 64 | if score_file is None: 65 | (basename_task, _) = os.path.splitext(task_file) 66 | (basename_dist, _) = os.path.splitext(distance_file) 67 | score_file = basename_task + '_' + basename_dist + '.score' 68 | # file verification: 69 | assert os.path.exists(task_file), 'Cannot find task file ' + task_file 70 | assert os.path.exists(distance_file), ('Cannot find distance file ' + 71 | distance_file) 72 | assert not os.path.exists(score_file), ('score file already exist ' + 73 | score_file) 74 | # with h5py.File(task_file) as t: 75 | # bys = [by for by in t['triplets']] 76 | # FIXME skip empty by datasets, this should not be necessary anymore when 77 | # empty datasets are filtered at the task file generation level 78 | with h5py.File(task_file, 'r') as t: 79 | bys = t['bys'][...] 80 | # bys = t['feat_dbs'].keys() 81 | n_triplets = t['triplets']['data'].shape[0] 82 | with h5py.File(score_file, 'w') as s: 83 | s.create_dataset('scores', (n_triplets, 1), dtype=np.int8) 84 | for n_by, by in enumerate(bys): 85 | with h5py.File(task_file, 'r') as t, h5py.File(distance_file, 'r') as d: 86 | trip_attrs = t['triplets']['by_index'][n_by] 87 | pair_attrs = t['unique_pairs'].attrs[by] 88 | # FIXME here we make the assumption 89 | # that this fits into memory ... 90 | dis = d['distances']['data'][pair_attrs[1]:pair_attrs[2]][...] 91 | dis = np.reshape(dis, dis.shape[0]) 92 | # FIXME idem + only unique_pairs used ? 93 | pairs = t['unique_pairs']['data'][pair_attrs[1]:pair_attrs[2]][...] 94 | pairs = np.reshape(pairs, pairs.shape[0]) 95 | base = pair_attrs[0] 96 | pair_key_type = type_fitting.fit_integer_type((base) ** 2 - 1, 97 | is_signed=False) 98 | with h52np.H52NP(task_file) as t: 99 | inp = t.add_subdataset('triplets', 'data', indexes=trip_attrs) 100 | idx_start = trip_attrs[0] 101 | for triplets in inp: 102 | triplets = pair_key_type(triplets) 103 | idx_end = idx_start + triplets.shape[0] 104 | 105 | pairs_AX = triplets[:, 0] + base * triplets[:, 2] 106 | # FIXME change the encoding (and type_fitting) so that 107 | # A,B and B,A have the same code ... 
(take a=min(a,b), 108 | # b=max(a,b)) 109 | pairs_BX = triplets[:, 1] + base * triplets[:, 2] 110 | dis_AX = dis[np.searchsorted(pairs, pairs_AX)] 111 | 112 | dis_BX = dis[np.searchsorted(pairs, pairs_BX)] 113 | scores = (np.int8(dis_AX < dis_BX) - 114 | np.int8(dis_AX > dis_BX)) 115 | # 1 if X closer to A, -1 if X closer to B, 0 if equal 116 | # distance (this doesn't use 0, 1/2, 1 to use the 117 | # compact np.int8 data format) 118 | s['scores'][idx_start:idx_end] = np.reshape(scores, (-1, 1)) 119 | idx_start = idx_end 120 | 121 | 122 | def main(): 123 | # parser (the usage string is specified explicitly because the default 124 | # does not show that the mandatory arguments must come before the optional 125 | # ones; otherwise parsing is not possible because optional arguments can 126 | # have various numbers of inputs) 127 | parser = argparse.ArgumentParser(usage="%(prog)s task distance [score]", 128 | description='ABX score computation') 129 | # I/O files 130 | g1 = parser.add_argument_group('I/O files') 131 | g1.add_argument('task', help='task file generated by the task module, \ 132 | containing the triplets and the pairs associated to the task \ 133 | specification') 134 | g1.add_argument('distance', help='distance file generated by the distance \ 135 | package, containing the distance between the pairs of a task') 136 | g1.add_argument('score', nargs='?', default=None, help='optional: score \ 137 | file, where the results of the computation will be put') 138 | args = parser.parse_args() 139 | 140 | if args.score is not None and os.path.exists(args.score): 141 | print("Warning: overwriting score file {}".format(args.score)) 142 | os.remove(args.score) 143 | score(args.task, args.distance, args.score) 144 | 145 | 146 | if __name__ == '__main__': 147 | main() 148 | -------------------------------------------------------------------------------- /ABXpy/sideop/regressor_manager.py: -------------------------------------------------------------------------------- 1 | # make sure the rest of the ABXpy package is accessible 2 | from six import iteritems 3 | 4 | import ABXpy.sideop.side_operations_manager as side_operations_manager 5 | import ABXpy.dbfun.dbfun_compute as dbfun_compute 6 | import ABXpy.dbfun.dbfun_lookuptable as dbfun_lookuptable 7 | import ABXpy.dbfun.dbfun_column as dbfun_column 8 | 9 | 10 | class RegressorManager(side_operations_manager.SideOperationsManager): 11 | 12 | """Manage the regressors on attributes (on, across, by) or elements (A, B, 13 | X) for further processing 14 | """ 15 | 16 | def __init__(self, db, db_hierarchy, on, across, by, regressors): 17 | side_operations_manager.SideOperationsManager.__init__( 18 | self, db_hierarchy, on, across, by) 19 | # add column functions for the default regressors: on_AB, on_X, 20 | # across_AX(s), across_B(s) (but not the by(s)) 21 | default_regressors = [on[0] + '_1', on[0] + '_2'] 22 | # check if no across were specified 23 | if not(self.across_cols == set(["#across"])): 24 | for col in self.across_cols: 25 | default_regressors.append(col + '_1') 26 | default_regressors.append(col + '_2') 27 | # FIXME add default regressors only if they are not already specified ? 28 | # FIXME do we really need to add the columns deriving from the original 29 | # on and across? 30 | regressors = regressors + default_regressors 31 | 32 | # reg can be: the name of a column of the database (possibly extended), 33 | # the name of a lookup file, the name of a script, or a script under the 34 | # form of a string (one that doesn't end in .dbfun...)
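        # For example (hypothetical values, assuming the item file defines a
        # 'talker' column and a 'context.dbfun' lookup table exists), the user
        # could pass regressors such as:
        #     ['talker_A', 'talker_X', 'context.dbfun', 'talker_A == talker_X']
        # mixing extended column names, a lookup table and an inline script.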
35 | for reg in regressors: 36 | # instantiate appropriate dbfun 37 | if reg in self.extended_cols: # column already in db 38 | col, _ = self.parse_extended_column(reg) 39 | db_fun = dbfun_column.DBfun_Column(reg, db, col, indexed=True) 40 | elif len(reg) >= 6 and reg[-6:] == '.dbfun': # lookup table 41 | # ask for re-interpreted indexed outputs 42 | db_fun = dbfun_lookuptable.DBfun_LookupTable(reg, indexed=True) 43 | else: # on the fly computation 44 | db_fun = dbfun_compute.DBfun_Compute(reg, self.extended_cols) 45 | self.add(db_fun) 46 | 47 | # regressor names and regressor index if needed 48 | 49 | # for generics: generate three versions of the regressor: _A, _B, and _X 50 | def classify_generic(self, elements, db_fun, db_variables): 51 | # check if there are only non-extended names 52 | if {s for r, s in elements} == set(['']): 53 | # for now raise an exception 54 | raise ValueError( 55 | 'You need to explicitly specify the columns for which you want regressors (using _A, _B and _X extensions)') 56 | # FIXME finish the following code to replace the current exception ... 57 | # change the code and/or the synopsis to replace all columns by their name +'_A', '_B', or '_X' 58 | # if db_fun.mode == 'table lookup': 59 | # definition = "with '%s' as reg: reg(%s%s, %s%s, ...)" % (db_fun.h5_file, db_fun.in_names[0], ext, db_fun.in_names[1], ext, ...) 60 | # else: 61 | # definition = f(db_fun.script) # f replaces all occurences of db_fun.extended_variables in the script string by _A,... version 62 | # need a function to regenerate python code from the a modified ast for this. 63 | # is it always a DBfun_Compute ? 64 | # reg_A = dbfun_compute.DBfun_Compute(definition, self.extended_columns) 65 | # reg_B = dbfun_compute.DBfun_Compute(definition, self.extended_columns) 66 | # reg_X = dbfun_compute.DBfun_Compute(definition, self.extended_columns) 67 | # self.add(reg_A) 68 | # self.add(reg_B) 69 | # self.add(reg_X) 70 | #elements = {} 71 | return elements, db_variables 72 | 73 | def set_by_regressors(self, by_values): 74 | self.by_regressors = [result for result in self.evaluate_by(by_values)] 75 | 76 | def set_on_across_by_regressors(self, on_across_by_values): 77 | self.on_across_by_regressors = [ 78 | result for result in self.evaluate_on_across_by(on_across_by_values)] 79 | 80 | def set_A_regressors(self, on_across_by_values, db, indices): 81 | self.A_regressors = [ 82 | result for result in self.evaluate_A(on_across_by_values, db, indices)] 83 | 84 | def set_B_regressors(self, on_across_by_values, db, indices): 85 | self.B_regressors = [ 86 | result for result in self.evaluate_B(on_across_by_values, db, indices)] 87 | 88 | def set_X_regressors(self, on_across_by_values, db, indices): 89 | self.X_regressors = [ 90 | result for result in self.evaluate_X(on_across_by_values, db, indices)] 91 | 92 | # FIXME implement ABX regressors 93 | def set_ABX_regressors(self, on_across_by_values, db, triplets): 94 | raise ValueError('ABX regressors not implemented') 95 | 96 | # FIXME current implem (here and also in dbfun.output_specs), does not 97 | # allow index sharing... 
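    # Note: get_regressor_info (below) flattens the names gathered for the
    # 'by', 'on_across_by', 'A', 'B', 'X' and 'ABX' groups into a single list
    # and merges their value indexes into one dictionary.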
98 | def get_regressor_info(self): 99 | names = [] 100 | indexes = {} 101 | reg_id = 0 102 | reg_id = self.fetch_regressor_info('by', reg_id) 103 | reg_id = self.fetch_regressor_info('on_across_by', reg_id) 104 | reg_id = self.fetch_regressor_info('A', reg_id) 105 | reg_id = self.fetch_regressor_info('B', reg_id) 106 | reg_id = self.fetch_regressor_info('X', reg_id) 107 | reg_id = self.fetch_regressor_info('ABX', reg_id) 108 | for field in ['by', 'on_across_by', 'A', 'B', 'X', 'ABX']: 109 | names = names + \ 110 | [name for name_list in getattr( 111 | self, field + '_names') for name in name_list] 112 | for dictionary in getattr(self, field + '_indexes'): 113 | for key, index in iteritems(dictionary): 114 | indexes[key] = index 115 | return names, indexes 116 | 117 | def fetch_regressor_info(self, field, reg_id): 118 | setattr(self, field + '_names', []) 119 | setattr(self, field + '_indexes', []) 120 | for db_fun in getattr(self, field): 121 | nb_o, o_names, o_indexes = db_fun.output_specs() 122 | if o_names is None: # give arbitrary names 123 | o_names = ['reg_' + str(reg_id + n) for n in range(nb_o)] 124 | reg_id = reg_id + nb_o 125 | getattr(self, field + '_names').append(o_names) 126 | getattr(self, field + '_indexes').append(o_indexes) 127 | return reg_id 128 | -------------------------------------------------------------------------------- /ABXpy/analyze.py: -------------------------------------------------------------------------------- 1 | """This module is used to analyse the results of an ABX discrimination task 2 | 3 | It collapses the result and give the mean score for each block of triplets 4 | sharing the same on, across and by labels. It output a tab separated csv file 5 | which columns are the relevant labels, the average score and the number of 6 | triplets in the block. See `Files format `_ for a more 7 | in-depth explanation. 8 | 9 | It requires a score file and a task file. 10 | 11 | Usage 12 | ----- 13 | Form the command line: 14 | 15 | .. code-block:: bash 16 | 17 | python analyze.py data.score data.abx data.csv 18 | 19 | In python: 20 | 21 | .. code-block:: python 22 | 23 | import ABXpy.analyze 24 | # Prerequisite: calculate a task data.abx, and a score data.score 25 | ABXpy.analyze.analyze(data.score, data.abx, data.csv) 26 | 27 | """ 28 | 29 | import h5py 30 | import numpy as np 31 | import argparse 32 | import os.path as path 33 | import os 34 | import warnings 35 | 36 | from ABXpy.misc.type_fitting import fit_integer_type 37 | 38 | 39 | def npdecode(keys, max_ind): 40 | """Vectorized implementation of the decoding of the labels: 41 | i = (a1*n2 + a2)*n3 + a3 ... 42 | """ 43 | res = np.empty((len(keys), len(max_ind))) 44 | aux = keys 45 | k = len(max_ind) 46 | for i in range(k - 1): 47 | res[:, k - 1 - i] = np.mod(aux, max_ind[k - 1 - i]) 48 | aux = np.divide(aux - res[:, k - 1 - i], max_ind[k - 1 - i]) 49 | res[:, 0] = aux 50 | return res 51 | 52 | 53 | def unique_rows(arr): 54 | """Numpy unique applied to the row only""" 55 | return (np.unique(np.ascontiguousarray(arr) 56 | .view(np.dtype((np.void, 57 | arr.dtype.itemsize * arr.shape[1])))) 58 | .view(arr.dtype).reshape(-1, arr.shape[1])) 59 | 60 | 61 | def collapse(scorefile, taskfile, fid): 62 | """Collapses the results for each triplets sharing the same on, across 63 | and by labels. 64 | 65 | """ 66 | # We make the assumption that everything fits in memory... 67 | scorefid = h5py.File(scorefile, 'r+') 68 | taskfid = h5py.File(taskfile, 'r') 69 | bys = taskfid['bys'][...] 
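    # For each 'by' block below, the regressor values of every triplet are
    # packed into a single integer key, the keys are sorted, the mean score is
    # taken over each unique key, and the decoded regressor labels are written
    # out with the averaged score and the triplet count.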
70 | for by_idx, by in enumerate(bys): 71 | # print 'collapsing {0}/{1}'.format(by_idx + 1, len(bys)) 72 | trip_attrs = taskfid['triplets']['by_index'][by_idx] 73 | 74 | tfrk = taskfid['regressors'][by] 75 | 76 | tmp = tfrk[u'indexed_data'] 77 | indices = np.array(tmp) 78 | if indices.size == 0: 79 | continue 80 | tmp = scorefid['scores'][trip_attrs[0]:trip_attrs[1]] 81 | scores_arr = np.array(tmp) 82 | tmp = np.ascontiguousarray(indices).view( 83 | np.dtype((np.void, indices.dtype.itemsize * indices.shape[1]))) 84 | n_indices = np.max(indices, 0) + 1 85 | assert np.prod(n_indices) < 18446744073709551615, "type not big enough" 86 | ind_type = fit_integer_type(np.prod(n_indices), 87 | is_signed=False) 88 | # encoding the indices of a triplet to a unique index 89 | new_index = indices[:, 0].astype(ind_type) 90 | for i in range(1, len(n_indices)): 91 | new_index = indices[:, i] + n_indices[i] * new_index 92 | 93 | permut = np.argsort(new_index) 94 | 95 | # collapsing the score 96 | sorted_scores = scores_arr[permut] 97 | sorted_index = new_index[permut] 98 | mean, unique_index, counts = unique(sorted_index, sorted_scores) 99 | 100 | # retrieving the triplet indices from the unique index. 101 | tmp = npdecode(unique_index, n_indices) 102 | 103 | regs = tfrk['indexed_datasets'] 104 | indexes = [] 105 | for reg in regs: 106 | indexes.append(tfrk['indexes'][reg][:]) 107 | nregs = len(regs) 108 | 109 | for i, key in enumerate(tmp): 110 | aux = list() 111 | for j in range(nregs): 112 | aux.append(indexes[j][int(key[j])]) 113 | score = mean[i] 114 | n = counts[i] 115 | result = aux + [by, score, int(n)] 116 | fid.write('\t'.join(map(str, result)) + u'\n') 117 | # results.append(aux + [context, score, n]) 118 | # wf_tmp.write('\t'.join(map(str, results[-1])) + '\n') 119 | scorefid.close() 120 | taskfid.close() 121 | del taskfid 122 | # wf_tmp.close() 123 | # return results 124 | 125 | 126 | def unique(index, scores): 127 | flag = np.concatenate( 128 | ([True], index[1:] != index[:-1], [True])) 129 | unique_idx = np.nonzero(flag)[0] 130 | counts = unique_idx[1:] - unique_idx[:-1] 131 | unique_index = index[unique_idx[:-1]] 132 | 133 | means = np.empty((len(counts)),) 134 | i = 0 135 | for a, c in enumerate(counts): 136 | means[a] = np.mean(scores[i:i+c]) 137 | i += c 138 | means = (means + 1) / 2 139 | return means, unique_index, counts 140 | 141 | 142 | def analyze(task_file, score_file, result_file): 143 | """Analyse the results of a task 144 | 145 | Parameters 146 | ---------- 147 | task_file : string, hdf5 file 148 | the file containing the triplets and pairs of the task 149 | score_file : string, hdf5 file 150 | the file containing the score of a task 151 | result_file: string, csv file 152 | the file that will contain the analysis results 153 | 154 | """ 155 | with open(result_file, 'w+') as fid: 156 | taskfid = h5py.File(task_file, 'r') 157 | aux = taskfid['regressors'] 158 | tfrk = aux[list(aux)[0]] 159 | regs = tfrk['indexed_datasets'] 160 | string = u'' 161 | for reg in regs: 162 | string += reg + '\t' 163 | string += 'by\tscore\tn\n' 164 | fid.write(string) 165 | taskfid.close() 166 | collapse(score_file, task_file, fid) 167 | 168 | 169 | def parse_args(): 170 | parser = argparse.ArgumentParser( 171 | formatter_class=argparse.RawDescriptionHelpFormatter, 172 | description='Collapse results of ABX score on type of ABX triplet.', 173 | epilog='''Example usage: 174 | 175 | $ ./analyze.py abx.score abx.task abx.csv 176 | 177 | compute the average the scores in abx.score by type of ABX triplet 
178 | and output the results in tab separated csv format.''') 179 | parser.add_argument('scorefile', metavar='SCORE', 180 | help='score file in hdf5 format') 181 | parser.add_argument('taskfile', metavar='TASK', 182 | help='task file in hdf5 format') 183 | parser.add_argument('output', metavar='OUTPUT', 184 | help='output file in csv format') 185 | return vars(parser.parse_args()) 186 | 187 | 188 | def main(): 189 | args = parse_args() 190 | 191 | score_file = args['scorefile'] 192 | if not path.exists(score_file): 193 | print('No such file: {}'.format(score_file)) 194 | exit() 195 | 196 | task_file = args['taskfile'] 197 | if not path.exists(task_file): 198 | print('No such file: {}'.format(task_file)) 199 | exit() 200 | 201 | result_file = args['output'] 202 | if os.path.exists(result_file): 203 | warnings.warn( 204 | 'Overwriting results file ' + args['output'], UserWarning) 205 | os.remove(result_file) 206 | 207 | analyze(task_file, score_file, result_file) 208 | 209 | 210 | if __name__ == '__main__': 211 | main() 212 | -------------------------------------------------------------------------------- /ABXpy/dbfun/dbfun_compute.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Oct 14 16:59:27 2013 4 | 5 | @author: Thomas Schatz 6 | 7 | Class for defining and computing efficiently functions of the columns of a 8 | database. 9 | Implements the DBfun API 10 | """ 11 | 12 | import ast 13 | import sys 14 | 15 | # Only solution I found for circular 16 | # imports in both Python 2 and 3 17 | from . import * 18 | from . import dbfun 19 | 20 | 21 | # FIXME remove dbfun prefix from dbfun_lookuptable and dbfun_connector ? 22 | class DBfun_Compute(dbfun.DBfun): 23 | def __init__(self, definition, columns): 24 | 25 | self.columns = set(columns) 26 | # set script 27 | if len(definition) >= 3 and definition[-3:] == '.py': 28 | with open(definition) as script_file: 29 | self.script = script_file.read() 30 | else: 31 | self.script = definition 32 | self.parse() 33 | # FIXME allow users to specify the content of n_outputs and/or 34 | # output_names from command-line (currently it will always be 1, 35 | # None...) 
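        # As an illustration (hypothetical calls), the definition can be either
        # a path to a script, e.g. DBfun_Compute('my_reg.py', columns), or an
        # inline expression such as DBfun_Compute('talker_A == talker_X',
        # columns), whose last line must evaluate to the regressor value.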
36 | self.n_outputs = 1 37 | self.output_names = None 38 | 39 | # slow but not likely to be critical 40 | def parse(self): 41 | """ 42 | first separate the script defining the function into various components 43 | (import statements, with 'h5file' statement, synopsis definition, main 44 | code) 45 | """ 46 | tree = ast.parse(self.script) 47 | # find and extract imports 48 | imports = [stat for stat in tree.body if isinstance( 49 | stat, (ast.ImportFrom, ast.Import))] 50 | rest = [stat for stat in tree.body if not( 51 | isinstance(stat, (ast.ImportFrom, ast.Import)))] 52 | 53 | # store ast with import statements 54 | # can be executed later using: exec(self.import_bytecode) 55 | if sys.version_info >= (3, 8): # ast.Module spec changed in python-3.8 56 | self.import_bytecode = compile(ast.Module(imports, []), '', mode='exec') 57 | tree = ast.Module(rest, []) 58 | else: 59 | self.import_bytecode = compile(ast.Module(imports), '', mode='exec') 60 | tree = ast.Module(rest) 61 | 62 | # look for a with statement with a string, store context info and 63 | # remove with statement 64 | tree = self.process_with(tree) 65 | # check that last line is an expression 66 | expression = tree.body[-1] 67 | if not(isinstance(expression, ast.Expr)): 68 | raise ValueError( 69 | 'The following script should finish by an expression: %s' 70 | % self.script) 71 | # store what is left 72 | self.main_ast = tree 73 | """ 74 | second find column names in the main code so as to determine what 75 | context is used and check coherence 76 | """ 77 | # find the list of the names of the variables in the main ast 78 | visitor = nameVisitor() # see definition of this class below 79 | visitor.visit(self.main_ast) 80 | names = set(visitor.names) 81 | # FIXME could add a check that all names correspond either to a bound 82 | # variable or is in self.columns 83 | # need a way to get a list of unbound variable ???? 84 | # then would raise ValueError('There are unbound variables in script 85 | # %s' % self.script) 86 | # For now: just consider that the inputs are the intersection of the 87 | # element of names and of self.columns 88 | # FIXME document that this means that using local variables with the 89 | # same name as db_columns in the scripts will affect the synopsis of 90 | # the dbfun ... 
91 | self.input_names = list(names.intersection(self.columns)) 92 | """ 93 | third parse final expression and additional code asts for nodes that 94 | involve aliases of the aux_files, get the hierarchy of calls 95 | to aux_files, check that it is flat, compile the corresponding partial 96 | bytecodes and collect all the info necessary for efficient evaluation 97 | of the code 98 | """ 99 | if self.aux_functions: 100 | connector = lookuptable_connector.LookupTableConnector( 101 | self.script, self.aliases) 102 | connector.visit(self.main_ast) 103 | # self.code_nodes contains a list of dictionaries containing the 104 | # info for the various calls to function defined in auxiliary h5 105 | # files 106 | # each dictionry contains entries: 'child_asts', 'varname' and 107 | # 'function' and optionnally 'output_cols' 108 | self.code_nodes = self.main_ast.code_nodes 109 | # clean the main ast 110 | del(self.main_ast.code_nodes) 111 | else: 112 | self.code_nodes = [] 113 | # compile all asts, with final expr apart 114 | # without the .value you only get a ast.Expr instead of getting the 115 | # actual expr 116 | self.final_ast = ast.Expression(self.main_ast.body[-1].value) 117 | self.main_ast.body = self.main_ast.body[:-1] 118 | for dico in self.code_nodes: 119 | bytecodes = [] 120 | for child in dico['child_asts']: 121 | bytecodes.append(compile(child, '', mode='eval')) 122 | dico['child_bytecodes'] = bytecodes 123 | self.main_bytecode = compile(self.main_ast, '', mode='exec') 124 | self.final_bytecode = compile(self.final_ast, '', mode='eval') 125 | 126 | # just an auxiliary function for parse, dealing with 'with h5file' 127 | # statements 128 | def process_with(self, tree): 129 | self.aux_files = [] 130 | self.aliases = [] 131 | self.aux_functions = [] 132 | withs = [stat for stat in tree.body if isinstance(stat, ast.With)] 133 | kept = [] 134 | for i, w in enumerate(withs): 135 | if isinstance(w.context_expr, ast.Str): 136 | kept.append((i, w)) 137 | if len(kept) > 1: 138 | raise ValueError( 139 | 'There is more than one with statement for re-using auxiliary' 140 | ' ABX files in script: %s' % self.script) 141 | if len(kept) == 1: 142 | # find the h5 files and aliases 143 | s = kept[0][1] 144 | while (isinstance(s, ast.With) and 145 | isinstance(w.context_expr, ast.Str)): 146 | self.aux_files.append(s.context_expr.s) 147 | self.aliases.append(s.optional_vars.id) 148 | s = s.body[0] 149 | # remove with statement from ast 150 | stats = [] 151 | with_i = 0 152 | for stat in tree.body: 153 | if isinstance(stat, ast.With): 154 | if with_i == kept[0][0]: 155 | stats.append(s) 156 | else: 157 | stats.append(stat) 158 | with_i = with_i + 1 159 | else: 160 | stats.append(stat) 161 | tree = ast.Module(stats) 162 | # instantiate corresponding DBfun_LookupTables: 163 | for f in self.aux_files: 164 | self.aux_functions.append( 165 | dbfun_lookuptable.DBfun_LookupTable(f, indexed=False)) 166 | return tree 167 | 168 | # FIXME if there is any sense in having indexed outputs for dbfun_compute, 169 | # implement it 170 | def output_specs(self): 171 | return self.n_outputs, self.output_names, {} 172 | 173 | # function for evaluating the column function given data for the context 174 | # context is a dictionary with just the right name/content associations 175 | def evaluate(self, context): 176 | # set up context 177 | ns_local = context 178 | ns_global = {} 179 | # exec imports in that context 180 | exec(self.import_bytecode, ns_global, ns_local) 181 | # evaluate the calls to aux functions 182 | for node in 
self.code_nodes: 183 | # evaluate the arguments to the call 184 | aux_context = {} 185 | args = node['function'].in_names 186 | for code, arg in zip(node['child_bytecodes'], args): 187 | aux_context[arg] = eval(code, ns_global, ns_local) 188 | # call the aux function and assign it in the main namespace 189 | ns_local[node['varname']] = node['function'].evaluate(aux_context) 190 | # FIXME if aux files, could use the output_cols here ? and maybe 191 | # need to do it also in direct case for consistency ? 192 | # also is output format for vlen output going to work ? 193 | # exec main_bytecode 194 | exec(self.main_bytecode, ns_global, ns_local) 195 | return eval(self.final_bytecode, ns_global, ns_local) 196 | 197 | 198 | # visitor class for getting the list of the names of the variables in expr 199 | # (minus the import statements) 200 | class nameVisitor(ast.NodeVisitor): 201 | def __init__(self, *args, **kwargs): 202 | ast.NodeVisitor.__init__(self, *args, **kwargs) 203 | self.names = [] 204 | 205 | def visit_Name(self, node): 206 | self.names.append(node.id) 207 | -------------------------------------------------------------------------------- /test/test_task.py: -------------------------------------------------------------------------------- 1 | """This test script contains tests for the basic parameters of task.py""" 2 | 3 | import h5py 4 | import numpy as np 5 | import os 6 | import warnings 7 | 8 | import ABXpy.task 9 | import ABXpy.misc.items as items 10 | 11 | error_pairs = "pairs incorrectly generated" 12 | error_triplets = "triplets incorrectly generated" 13 | 14 | 15 | # not optimized, but unimportant 16 | def tables_equivalent(t1, t2): 17 | assert t1.shape == t2.shape 18 | for a1 in t1: 19 | res = False 20 | for a2 in t2: 21 | if np.array_equal(a1, a2): 22 | res = True 23 | if not res: 24 | return False 25 | return True 26 | 27 | 28 | def get_triplets(hdf5file, by): 29 | triplet_db = hdf5file['triplets'] 30 | triplets = triplet_db['data'] 31 | by_index = list(hdf5file['bys']).index(by) 32 | triplets_index = triplet_db['by_index'][by_index] 33 | return triplets[slice(*triplets_index)] 34 | 35 | 36 | def get_pairs(hdf5file, by): 37 | pairs_db = hdf5file['unique_pairs'] 38 | pairs = pairs_db['data'] 39 | pairs_index = pairs_db.attrs[by][1:3] 40 | return pairs[slice(*pairs_index)] 41 | 42 | 43 | # test1, triplets and pairs verification 44 | def test_basic(): 45 | items.generate_testitems(2, 3, name='data.item') 46 | try: 47 | task = ABXpy.task.Task('data.item', 'c0', 'c1', 'c2') 48 | stats = task.stats 49 | assert stats['nb_blocks'] == 8, "incorrect stats: number of blocks" 50 | assert stats['nb_triplets'] == 8 51 | assert stats['nb_by_levels'] == 2 52 | task.generate_triplets() 53 | f = h5py.File('data.abx', 'r') 54 | triplets = f['triplets']['data'][...] 55 | by_indexes = f['triplets']['by_index'][...] 
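        # each row of 'by_index' holds the (start, stop) bounds of one by-block
        # inside the global 'triplets/data' array (cf. get_triplets above)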
56 | triplets_block0 = triplets[slice(*by_indexes[0])] 57 | triplets_block1 = triplets[slice(*by_indexes[1])] 58 | triplets_block0 = get_triplets(f, '0') 59 | triplets_block1 = get_triplets(f, '1') 60 | triplets = np.array([[0, 1, 2], [1, 0, 3], [2, 3, 0], [3, 2, 1]]) 61 | assert tables_equivalent(triplets, triplets_block0), error_triplets 62 | assert tables_equivalent(triplets, triplets_block1), error_triplets 63 | pairs = [2, 6, 7, 3, 8, 12, 13, 9] 64 | pairs_block0 = get_pairs(f, '0') 65 | pairs_block1 = get_pairs(f, '1') 66 | assert (set(pairs) == set(pairs_block0[:, 0])), error_pairs 67 | assert (set(pairs) == set(pairs_block1[:, 0])), error_pairs 68 | finally: 69 | try: 70 | os.remove('data.abx') 71 | os.remove('data.item') 72 | except OSError: 73 | pass 74 | 75 | 76 | # testing with a list of across attributes, triplets verification 77 | def test_multiple_across(): 78 | items.generate_testitems(2, 3, name='data.item') 79 | try: 80 | with warnings.catch_warnings(): 81 | warnings.simplefilter("ignore") 82 | task = ABXpy.task.Task('data.item', 'c0', ['c1', 'c2']) 83 | 84 | stats = task.stats 85 | assert stats['nb_blocks'] == 8 86 | assert stats['nb_triplets'] == 8 87 | assert stats['nb_by_levels'] == 1 88 | 89 | task.generate_triplets() 90 | 91 | f = h5py.File('data.abx', 'r') 92 | triplets_block = get_triplets(f, '0') 93 | triplets = np.array([[0, 1, 6], [1, 0, 7], [2, 3, 4], [3, 2, 5], 94 | [4, 5, 2], [5, 4, 3], [6, 7, 0], [7, 6, 1]]) 95 | assert tables_equivalent(triplets, triplets_block) 96 | finally: 97 | try: 98 | os.remove('data.abx') 99 | os.remove('data.item') 100 | except OSError: 101 | pass 102 | 103 | 104 | # testing without any across attribute 105 | def test_no_across(): 106 | items.generate_testitems(2, 3, name='data.item') 107 | try: 108 | task = ABXpy.task.Task('data.item', 'c0', None, 'c2') 109 | stats = task.stats 110 | assert stats['nb_blocks'] == 8 111 | assert stats['nb_triplets'] == 16 112 | assert stats['nb_by_levels'] == 2 113 | task.generate_triplets() 114 | finally: 115 | try: 116 | os.remove('data.abx') 117 | os.remove('data.item') 118 | except OSError: 119 | pass 120 | 121 | 122 | # testing for multiple by attributes, asserting the statistics 123 | def test_multiple_bys(): 124 | items.generate_testitems(3, 4, name='data.item') 125 | try: 126 | task = ABXpy.task.Task('data.item', 'c0', None, ['c1', 'c2', 'c3']) 127 | stats = task.stats 128 | assert stats['nb_blocks'] == 81 129 | assert stats['nb_triplets'] == 0 130 | assert stats['nb_by_levels'] == 27 131 | with warnings.catch_warnings(): 132 | warnings.simplefilter("ignore") 133 | task.generate_triplets() 134 | finally: 135 | try: 136 | os.remove('data.abx') 137 | os.remove('data.item') 138 | except OSError: 139 | pass 140 | 141 | 142 | # testing for a general filter (discarding last column) 143 | def test_filter(): 144 | items.generate_testitems(2, 4, name='data.item') 145 | try: 146 | task = ABXpy.task.Task('data.item', 'c0', 'c1', 'c2', 147 | filters=["[attr == 0 for attr in c3]"]) 148 | stats = task.stats 149 | assert stats['nb_blocks'] == 8, "incorrect stats: number of blocks" 150 | assert stats['nb_triplets'] == 8 151 | assert stats['nb_by_levels'] == 2 152 | task.generate_triplets(output='data.abx') 153 | f = h5py.File('data.abx', 'r') 154 | triplets_block0 = get_triplets(f, '0') 155 | triplets_block1 = get_triplets(f, '1') 156 | triplets = np.array([[0, 1, 2], [1, 0, 3], [2, 3, 0], [3, 2, 1]]) 157 | assert tables_equivalent(triplets, triplets_block0), error_triplets 158 | assert 
tables_equivalent(triplets, triplets_block1), error_triplets 159 | pairs = [2, 6, 7, 3, 8, 12, 13, 9] 160 | pairs_block0 = get_pairs(f, '0') 161 | pairs_block1 = get_pairs(f, '1') 162 | assert (set(pairs) == set(pairs_block0[:, 0])), error_pairs 163 | assert (set(pairs) == set(pairs_block1[:, 0])), error_pairs 164 | finally: 165 | try: 166 | os.remove('data.abx') 167 | os.remove('data.item') 168 | except OSError: 169 | pass 170 | 171 | 172 | # testing with simple filter on A, verifying triplet generation 173 | def test_filter_on_A(): 174 | items.generate_testitems(2, 2, name='data.item') 175 | try: 176 | task = ABXpy.task.Task('data.item', 'c0', 177 | filters=["[attr == 0 for attr in c0_A]"]) 178 | stats = task.stats 179 | assert stats['nb_blocks'] == 4, "incorrect stats: number of blocks" 180 | assert stats['nb_triplets'] == 4 181 | assert stats['nb_by_levels'] == 1 182 | task.generate_triplets() 183 | f = h5py.File('data.abx', 'r') 184 | triplets_block0 = get_triplets(f, '0') 185 | triplets = np.array([[0, 1, 2], [0, 3, 2], [2, 1, 0], [2, 3, 0]]) 186 | assert tables_equivalent(triplets, triplets_block0), error_triplets 187 | finally: 188 | try: 189 | os.remove('data.abx') 190 | os.remove('data.item') 191 | except OSError: 192 | pass 193 | 194 | 195 | # testing with simple filter on B, verifying triplet generation 196 | def test_filter_on_B(): 197 | items.generate_testitems(2, 2, name='data.item') 198 | try: 199 | task = ABXpy.task.Task('data.item', 'c0', 200 | filters=["[attr == 0 for attr in c1_B]"]) 201 | stats = task.stats 202 | assert stats['nb_blocks'] == 4, "incorrect stats: number of blocks" 203 | assert stats['nb_triplets'] == 4 204 | assert stats['nb_by_levels'] == 1 205 | task.generate_triplets() 206 | f = h5py.File('data.abx', 'r') 207 | triplets_block0 = get_triplets(f, '0') 208 | triplets = np.array([[0, 1, 2], [1, 0, 3], [2, 1, 0], [3, 0, 1]]) 209 | assert tables_equivalent(triplets, triplets_block0), error_triplets 210 | finally: 211 | try: 212 | os.remove('data.abx') 213 | os.remove('data.item') 214 | except OSError: 215 | pass 216 | 217 | 218 | # testing with simple filter on B, verifying triplet generation 219 | def test_filter_on_C(): 220 | items.generate_testitems(2, 2, name='data.item') 221 | try: 222 | task = ABXpy.task.Task('data.item', 223 | 'c0', 224 | filters=["[attr == 0 for attr in c1_X]"]) 225 | stats = task.stats 226 | assert stats['nb_blocks'] == 4, "incorrect stats: number of blocks" 227 | assert stats['nb_triplets'] == 4 228 | assert stats['nb_by_levels'] == 1 229 | task.generate_triplets() 230 | f = h5py.File('data.abx', 'r') 231 | triplets_block0 = get_triplets(f, '0') 232 | triplets = np.array([[2, 1, 0], [2, 3, 0], [3, 0, 1], [3, 2, 1]]) 233 | assert tables_equivalent(triplets, triplets_block0), error_triplets 234 | finally: 235 | try: 236 | os.remove('data.abx') 237 | os.remove('data.item') 238 | except OSError: 239 | pass 240 | -------------------------------------------------------------------------------- /ABXpy/h5tools/np2h5.py: -------------------------------------------------------------------------------- 1 | """Class for efficiently writing to disk (in a dataset of a HDF5 file) 2 | 3 | Simple two-dimensional numpy arrays that are incrementally generated 4 | along the first dimension. It uses buffers to avoid small I/O. 5 | 6 | It needs to be used within a 'with' statement, so as to handle buffer 7 | flushing and opening and closing of the underlying HDF5 file smoothly. 8 | 9 | Buffer size should be chosen according to speed/memory trade-off. 
Due 10 | to cache issues there is probably an optimal size. 11 | 12 | The size of the dataset to be written must be known in advance, 13 | excepted when overwriting an existing dataset. Not writing exactly 14 | the expected amount of data causes an Exception to be thrown excepted 15 | is the fixed_size option was set to False when adding the dataset. 16 | 17 | """ 18 | import numpy as np 19 | import h5py 20 | 21 | 22 | class NP2H5(object): 23 | # sink is the name of the HDF5 file to which to write, buffer size is in 24 | # kilobytes 25 | 26 | def __init__(self, h5file): 27 | # set up output file and buffer list 28 | if isinstance(h5file, str): 29 | self.manage_file = True 30 | self.filename = h5file 31 | self.file_open = False 32 | else: # supposed to be a h5 file handle 33 | self.manage_file = False 34 | self.file_open = True 35 | self.file = h5file 36 | self.filename = h5file.filename 37 | self.buffers = [] 38 | 39 | # open HDF5 file in 'with' statement 40 | def __enter__(self): 41 | if not(self.file_open): 42 | self.file = h5py.File(self.filename, 'a') 43 | self.file_open = True 44 | return self 45 | 46 | # flush buffers and close HDF5 file in 'with' statement 47 | def __exit__(self, eType, eValue, eTrace): 48 | try: 49 | if self.file_open: 50 | for buf in self.buffers: 51 | buf.flush() 52 | if self.manage_file: 53 | self.file.close() 54 | self.file_open = False 55 | # if there was an error, delete dataset, otherwise check that the 56 | # amount of data actually written is consistent with the size of 57 | # the datasets 58 | if eValue is not None: 59 | if not(self.file_open): 60 | self.file = h5py.File(self.filename, 'a') 61 | self.file_open = True 62 | for buf in self.buffers: 63 | buf.delete() 64 | if self.manage_file: 65 | self.file.close() 66 | self.file_open = False 67 | # check that all buffers were completed (defaults to true for 68 | # dataset without a fixed size) 69 | elif not(all([buf.iscomplete() for buf in self.buffers])): 70 | raise Warning( 71 | 'File %s, the amount of data actually written is not consistent with the size of the datasets' % self.filename) 72 | except: 73 | # raise the first exception 74 | if eValue is not None: 75 | # FIXME the first exception will be raised, but could log a 76 | # warning here ... 
77 | pass 78 | else: 79 | raise 80 | 81 | def add_dataset(self, group, dataset, n_rows=0, n_columns=None, chunk_size=10, buf_size=100, item_type=np.int64, overwrite=False, fixed_size=True): 82 | if n_columns is None: 83 | raise ValueError( 84 | 'You have to specify the number of columns of the dataset.') 85 | if self.file_open: 86 | buf = NP2H5buffer(self, group, dataset, n_rows, n_columns, 87 | chunk_size, buf_size, item_type, overwrite, fixed_size) 88 | self.buffers.append(buf) 89 | return buf 90 | else: 91 | raise IOError( 92 | "Method add_dataset of class NP2H5 can only be used within a 'with' statement!") 93 | 94 | 95 | class NP2H5buffer(object): 96 | 97 | # buf_size in Ko 98 | 99 | def __init__(self, parent, group, dataset, n_rows, n_columns, chunk_size, buf_size, item_type, overwrite, fixed_size): 100 | 101 | assert parent.file_open 102 | 103 | # check coherency of arguments if size is fixed or not 104 | if n_rows == 0 and fixed_size: 105 | raise ValueError( 106 | 'A dataset with a fixed size cannot have zero lines') 107 | if overwrite and not(fixed_size): 108 | raise ValueError( 109 | 'Cannot overwrite a dataset without a specified fixed size') 110 | self.fixed_size = fixed_size 111 | 112 | # check type argument 113 | # dtype call is needed to access the itemsize attribute in case a 114 | # built-in type was specified 115 | self.type = np.dtype(item_type) 116 | if self.type.itemsize == 0: 117 | raise AttributeError( 118 | 'NP2H5 can only be used with numpy arrays whose items have a fixed size in memory') 119 | 120 | # initialize buffer 121 | self.buf_len = nb_lines(self.type.itemsize, n_columns, buf_size) 122 | self.buf = np.zeros([self.buf_len, n_columns], dtype=self.type) 123 | self.buf_ix = 0 124 | 125 | # set up output dataset 126 | self.dataset_ix = 0 127 | # fail if dataset already exists and overwrite=False otherwise create 128 | # it or overwrite it 129 | if group + '/' + dataset in parent.file: 130 | if overwrite: 131 | self.dataset = parent.file[group][dataset] 132 | if self.dataset.shape[0] != n_rows or self.dataset.shape[1] != n_columns or self.dataset.dtype != self.type: 133 | raise IOError( 134 | 'Overwriting a dataset is only possible if it already has the correct shape and dtype') 135 | else: 136 | raise IOError( 137 | 'Dataset %s already exists in file %s!' % (dataset, parent.filename)) 138 | else: 139 | # if necessary create group 140 | try: 141 | g = parent.file[group] 142 | except KeyError: 143 | g = parent.file.create_group(group) 144 | # create dataset 145 | if self.fixed_size: 146 | # would it be useful to chunk here? 
147 | g.create_dataset(dataset, (n_rows, n_columns), dtype=self.type) 148 | else: 149 | chunk_lines = nb_lines( 150 | self.type.itemsize, n_columns, chunk_size) 151 | g.create_dataset(dataset, (n_rows, n_columns), dtype=self.type, chunks=( 152 | chunk_lines, n_columns), maxshape=(None, n_columns)) 153 | self.dataset = parent.file[group][dataset] 154 | 155 | # store useful parameters 156 | self.n_rows = n_rows 157 | self.n_columns = n_columns 158 | self.parent = parent 159 | 160 | def write(self, data): 161 | # fail if not used in a with statement of a parent NP2H5 object 162 | if not(self.parent.file_open): 163 | raise IOError( 164 | "Method write of class NP2H5buffer can only be used within a 'with' statement of parent NP2H5 object!") 165 | target_ix = self.buf_ix + data.shape[0] 166 | # if size is not of fixed size, check that it is big enough 167 | if not(self.fixed_size): 168 | necessary_rows = self.dataset_ix + \ 169 | self.buf_len * (target_ix // self.buf_len) 170 | if necessary_rows > self.n_rows: 171 | self.n_rows = necessary_rows 172 | # maybe should use larger increments ? could use chunk size as 173 | # a basis for the increments instead of buf_len if useful 174 | self.dataset.resize((self.n_rows, self.n_columns)) 175 | # while buffer is full dump it to file 176 | while target_ix >= self.buf_len: 177 | # fill buffer 178 | buffer_space = self.buf_len - self.buf_ix 179 | self.buf[self.buf_ix:] = data[:buffer_space, :] 180 | # dump buffer to file 181 | ix_start = self.dataset_ix 182 | ix_end = self.dataset_ix + self.buf_len 183 | self.dataset[ix_start:ix_end, :] = self.buf 184 | self.dataset_ix = ix_end 185 | # reset variables for next iteration 186 | self.buf_ix = 0 187 | data = data[buffer_space:, :] 188 | target_ix = target_ix - self.buf_len 189 | # put remaining data in buffer 190 | self.buf[self.buf_ix:target_ix, :] = data 191 | self.buf_ix = target_ix 192 | 193 | def flush(self): 194 | assert self.parent.file_open 195 | if self.buf_ix > 0: 196 | ix_start = self.dataset_ix 197 | ix_end = self.dataset_ix + self.buf_ix 198 | if not(self.fixed_size) and ix_end > self.n_rows: 199 | self.dataset.resize((ix_end, self.n_columns)) 200 | self.dataset[ix_start:ix_end, :] = self.buf[:self.buf_ix] 201 | self.dataset_ix = ix_end 202 | self.buf_ix = 0 203 | 204 | def delete(self): 205 | assert self.parent.file_open 206 | del self.dataset 207 | 208 | def iscomplete(self): 209 | if self.fixed_size: 210 | test = self.dataset_ix == self.n_rows 211 | else: 212 | test = True 213 | return test 214 | 215 | 216 | # item_size given in bytes, size_in_mem given in kilobytes 217 | def nb_lines(item_size, n_columns, size_in_mem): 218 | return int(round(size_in_mem * 1000. / (item_size * n_columns))) 219 | -------------------------------------------------------------------------------- /doc/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # ABXpy documentation build configuration file, created by 4 | # sphinx-quickstart on Wed May 7 22:50:13 2014. 5 | # 6 | # This file is execfile()d with the current directory set to its 7 | # containing dir. 8 | # 9 | # Note that not all possible configuration values are present in this 10 | # autogenerated file. 11 | # 12 | # All configuration values have a default; values that are commented out 13 | # serve to show the default. 
14 | 15 | import ABXpy 16 | import datetime 17 | import os 18 | import sys 19 | 20 | # mocking for ReadTheDoc 21 | from mock import Mock as MagicMock 22 | 23 | # If extensions (or modules to document with autodoc) are in another directory, 24 | # add these directories to sys.path here. If the directory is relative to the 25 | # documentation root, use os.path.abspath to make it absolute, like shown here. 26 | sys.path.insert(0, os.path.abspath('.')) 27 | sys.path.insert( 28 | 0, os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 29 | 30 | 31 | class Mock(MagicMock): 32 | @classmethod 33 | def __getattr__(cls, name): 34 | return Mock() 35 | 36 | 37 | MOCK_MODULES = ['tables', 'scipy', 'numpy', 'pandas', 'h5py', 'h5features'] 38 | sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES) 39 | 40 | 41 | # -- General configuration ----------------------------------------------- 42 | 43 | # If your documentation needs a minimal Sphinx version, state it here. 44 | # needs_sphinx = '1.0' 45 | 46 | # Add any Sphinx extension module names here, as strings. They can be 47 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 48 | # ones. 49 | extensions = ['sphinx.ext.autodoc', 50 | 'sphinx.ext.viewcode', 51 | 'sphinx.ext.autosummary', 52 | 'numpydoc'] 53 | 54 | # Add any paths that contain templates here, relative to this directory. 55 | templates_path = ['_templates'] 56 | 57 | 58 | # The encoding of source files. 59 | # source_encoding = 'utf-8-sig' 60 | 61 | # The master toctree document. 62 | master_doc = 'index' 63 | 64 | # General information about the project. 65 | project = 'ABXpy' 66 | copyright = '2014 - {}, CoML team'.format(datetime.datetime.now().year) 67 | 68 | # The suffix of source filenames. 69 | source_suffix = '.rst' 70 | 71 | rst_epilog = '.. |copyright| replace:: %s' % copyright 72 | 73 | # The version info for the project you're documenting, acts as replacement for 74 | # |version| and |release|, also used in various other places throughout the 75 | # built documents. 76 | # 77 | # The short X.Y version. 78 | version = ABXpy.version 79 | # The full version, including alpha/beta/rc tags. 80 | release = version 81 | 82 | # The language for content autogenerated by Sphinx. Refer to documentation 83 | # for a list of supported languages. 84 | # language = None 85 | 86 | # There are two options for replacing |today|: either, you set today to some 87 | # non-false value, then it is used: 88 | # today = '' 89 | # Else, today_fmt is used as the format for a strftime call. 90 | # today_fmt = '%B %d, %Y' 91 | 92 | # List of patterns, relative to source directory, that match files and 93 | # directories to ignore when looking for source files. 94 | exclude_patterns = ['_build'] 95 | 96 | # The reST default role (used for this markup: `text`) to use for all 97 | # documents. 98 | # default_role = None 99 | 100 | # If true, '()' will be appended to :func: etc. cross-reference text. 101 | # add_function_parentheses = True 102 | 103 | # If true, the current module name will be prepended to all description 104 | # unit titles (such as .. function::). 105 | # add_module_names = True 106 | 107 | # If true, sectionauthor and moduleauthor directives will be shown in the 108 | # output. They are ignored by default. 109 | # show_authors = False 110 | 111 | # The name of the Pygments (syntax highlighting) style to use. 112 | pygments_style = 'sphinx' 113 | 114 | # A list of ignored prefixes for module index sorting. 
115 | # modindex_common_prefix = [] 116 | 117 | 118 | # -- Options for HTML output --------------------------------------------- 119 | 120 | # The theme to use for HTML and HTML Help pages. See the documentation for 121 | # a list of builtin themes. 122 | html_theme = "sphinx_rtd_theme" 123 | 124 | # Theme options are theme-specific and customize the look and feel of a theme 125 | # further. For a list of options available for each theme, see the 126 | # documentation. 127 | #html_theme_options = {} 128 | 129 | # Add any paths that contain custom themes here, relative to this directory. 130 | #html_theme_path = [] 131 | 132 | # The name for this set of Sphinx documents. If None, it defaults to 133 | # " v documentation". 134 | #html_title = None 135 | 136 | # A shorter title for the navigation bar. Default is the same as html_title. 137 | #html_short_title = None 138 | 139 | # The name of an image file (relative to this directory) to place at the top 140 | # of the sidebar. 141 | #html_logo = None 142 | 143 | # The name of an image file (within the static path) to use as favicon of the 144 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 145 | # pixels large. 146 | #html_favicon = None 147 | 148 | # Add any paths that contain custom static files (such as style sheets) here, 149 | # relative to this directory. They are copied after the builtin static files, 150 | # so a file named "default.css" will overwrite the builtin "default.css". 151 | html_static_path = ['_static'] 152 | 153 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, 154 | # using the given strftime format. 155 | #html_last_updated_fmt = '%b %d, %Y' 156 | 157 | # If true, SmartyPants will be used to convert quotes and dashes to 158 | # typographically correct entities. 159 | #html_use_smartypants = True 160 | 161 | # Custom sidebar templates, maps document names to template names. 162 | #html_sidebars = {} 163 | 164 | # Additional templates that should be rendered to pages, maps page names to 165 | # template names. 166 | #html_additional_pages = {} 167 | 168 | # If false, no module index is generated. 169 | #html_domain_indices = True 170 | 171 | # If false, no index is generated. 172 | #html_use_index = True 173 | 174 | # If true, the index is split into individual pages for each letter. 175 | #html_split_index = False 176 | 177 | # If true, links to the reST sources are added to the pages. 178 | #html_show_sourcelink = True 179 | 180 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 181 | #html_show_sphinx = True 182 | 183 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 184 | #html_show_copyright = True 185 | 186 | # If true, an OpenSearch description file will be output, and all pages will 187 | # contain a tag referring to it. The value of this option must be the 188 | # base URL from which the finished HTML is served. 189 | #html_use_opensearch = '' 190 | 191 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 192 | #html_file_suffix = None 193 | 194 | # Output file base name for HTML help builder. 195 | htmlhelp_basename = 'ABXpydoc' 196 | 197 | 198 | # -- Options for LaTeX output -------------------------------------------- 199 | 200 | latex_elements = { 201 | # The paper size ('letterpaper' or 'a4paper'). 202 | #'papersize': 'letterpaper', 203 | 204 | # The font size ('10pt', '11pt' or '12pt'). 205 | #'pointsize': '10pt', 206 | 207 | # Additional stuff for the LaTeX preamble. 
208 | #'preamble': '', 209 | } 210 | 211 | # Grouping the document tree into LaTeX files. List of tuples 212 | # (source start file, target name, title, author, documentclass [howto/manual]). 213 | latex_documents = [ 214 | ('index', 'ABXpy.tex', u'ABXpy Documentation', 215 | u'Author', 'manual'), 216 | ] 217 | 218 | # The name of an image file (relative to this directory) to place at the top of 219 | # the title page. 220 | #latex_logo = None 221 | 222 | # For "manual" documents, if this is true, then toplevel headings are parts, 223 | # not chapters. 224 | #latex_use_parts = False 225 | 226 | # If true, show page references after internal links. 227 | #latex_show_pagerefs = False 228 | 229 | # If true, show URL addresses after external links. 230 | #latex_show_urls = False 231 | 232 | # Documents to append as an appendix to all manuals. 233 | #latex_appendices = [] 234 | 235 | # If false, no module index is generated. 236 | #latex_domain_indices = True 237 | 238 | 239 | # -- Options for manual page output -------------------------------------- 240 | 241 | # One entry per manual page. List of tuples 242 | # (source start file, name, description, authors, manual section). 243 | man_pages = [ 244 | ('index', 'abxpy', u'ABXpy Documentation', 245 | [u'Author'], 1) 246 | ] 247 | 248 | # If true, show URL addresses after external links. 249 | #man_show_urls = False 250 | 251 | 252 | # -- Options for Texinfo output ------------------------------------------ 253 | 254 | # Grouping the document tree into Texinfo files. List of tuples 255 | # (source start file, target name, title, author, 256 | # dir menu entry, description, category) 257 | texinfo_documents = [ 258 | ('index', 'ABXpy', u'ABXpy Documentation', 259 | u'Author', 'ABXpy', 'One line description of project.', 260 | 'Miscellaneous'), 261 | ] 262 | 263 | # Documents to append as an appendix to all manuals. 264 | #texinfo_appendices = [] 265 | 266 | # If false, no module index is generated. 267 | #texinfo_domain_indices = True 268 | 269 | # How to display URL addresses: 'footnote', 'no', or 'inline'. 270 | #texinfo_show_urls = 'footnote' 271 | 272 | 273 | # -- Options for Epub output --------------------------------------------- 274 | 275 | # Bibliographic Dublin Core info. 276 | epub_title = u'ABXpy' 277 | epub_author = u'Author' 278 | epub_publisher = u'Author' 279 | epub_copyright = u'2014, Author' 280 | 281 | # The language of the text. It defaults to the language option 282 | # or en if the language is not set. 283 | #epub_language = '' 284 | 285 | # The scheme of the identifier. Typical schemes are ISBN or URL. 286 | #epub_scheme = '' 287 | 288 | # The unique identifier of the text. This can be a ISBN number 289 | # or the project homepage. 290 | #epub_identifier = '' 291 | 292 | # A unique identification for the text. 293 | #epub_uid = '' 294 | 295 | # A tuple containing the cover image and cover page html template filenames. 296 | #epub_cover = () 297 | 298 | # HTML files that should be inserted before the pages created by sphinx. 299 | # The format is a list of tuples containing the path and title. 300 | #epub_pre_files = [] 301 | 302 | # HTML files shat should be inserted after the pages created by sphinx. 303 | # The format is a list of tuples containing the path and title. 304 | #epub_post_files = [] 305 | 306 | # A list of files that should not be packed into the epub file. 307 | #epub_exclude_files = [] 308 | 309 | # The depth of the table of contents in toc.ncx. 
310 | #epub_tocdepth = 3 311 | 312 | # Allow duplicate toc entries. 313 | #epub_tocdup = True 314 | 315 | # Fixing conflict between numpydoc and autosummary 316 | numpydoc_show_class_members = False 317 | -------------------------------------------------------------------------------- /ABXpy/h5tools/h52np.py: -------------------------------------------------------------------------------- 1 | """Read efficiently h5 files 2 | 3 | Includes functions useful for merging sorted datasets. 4 | 5 | Some code is shared by H52NP and NP2H5: could have a superclass: 6 | optionally_h5_context_manager who would consist in implementing 7 | __init__, __enter__, __exit__ where a filename or a file handle can be 8 | passed and the file should be handled by the context manager only if a 9 | filename is passed. 10 | 11 | Also the functionalities specific to sorted datasets could be put in a 12 | subclass. 13 | 14 | """ 15 | 16 | import numpy as np 17 | import bisect 18 | import h5py 19 | 20 | 21 | class H52NP(object): 22 | # sink is the name of the HDF5 file to which to write, buffer size is in 23 | # kilobytes 24 | 25 | def __init__(self, h5file): 26 | # set up output file and buffer list 27 | if isinstance(h5file, str): 28 | self.manage_file = True 29 | self.filename = h5file 30 | self.file_open = False 31 | else: # supposed to be a h5 file handle 32 | self.manage_file = False 33 | self.file_open = True 34 | self.file = h5file 35 | self.filename = h5file.filename 36 | self.buffers = [] 37 | 38 | def __enter__(self): 39 | if not(self.file_open): 40 | self.file = h5py.File(self.filename, 'r+') 41 | self.file_open = True 42 | return self 43 | 44 | def __exit__(self, eType, eValue, eTrace): 45 | if self.file_open and self.manage_file: 46 | try: 47 | self.file.close() 48 | self.file_open = False 49 | except: 50 | if eValue is not None: 51 | # the first exception will be raised, but could log a 52 | # warning here ... 53 | pass 54 | else: 55 | raise 56 | 57 | def add_dataset(self, group, dataset, buf_size=100, 58 | minimum_occupied_portion=0.25): 59 | if self.file_open: 60 | buf = H52NPbuffer( 61 | self, group, dataset, buf_size, minimum_occupied_portion) 62 | self.buffers.append(buf) 63 | return buf 64 | else: 65 | raise IOError( 66 | "Method add_dataset of class H52NP can only be used " 67 | "within a 'with' statement!") 68 | 69 | def add_subdataset(self, group, dataset, buf_size=100, 70 | minimum_occupied_portion=0.25, indexes=None): 71 | if self.file_open: 72 | buf = H5dataset2NPbuffer( 73 | self, group, dataset, buf_size, minimum_occupied_portion, 74 | indexes=indexes) 75 | self.buffers.append(buf) 76 | return buf 77 | else: 78 | raise IOError( 79 | "Method add_dataset of class H52NP can only be used " 80 | "within a 'with' statement!") 81 | 82 | 83 | class H52NPbuffer(object): 84 | def __init__(self, parent, group, dataset, buf_size, 85 | minimum_occupied_portion): 86 | 87 | assert parent.file_open 88 | 89 | # get info from dataset 90 | # fail if dataset do not exist 91 | if not(group + '/' + dataset in parent.file): 92 | raise IOError('Dataset %s does not exists in file %s!' % 93 | (dataset, parent.filename)) 94 | dset = parent.file[group][dataset] 95 | self.n_rows = dset.shape[0] 96 | self.n_columns = dset.shape[1] 97 | self.type = dset.dtype 98 | self.dataset = dset 99 | self.dataset_ix = 0 100 | # could add checks: no more than 2 dims, etc. 
101 | 102 | self.parent = parent 103 | self.minimum_occupied_portion = minimum_occupied_portion 104 | 105 | # initialize buffer 106 | row_size = self.n_columns * self.type.itemsize / \ 107 | 1000. # entry size in kilobytes 108 | self.buf_len = int(round(buf_size / row_size)) 109 | # buf_ix represents the number of free rows in the buffer. Here the 110 | # buffer is empty 111 | self.buf_ix = self.buf_len 112 | self.buf = np.zeros((self.buf_len, self.n_columns), self.type) 113 | # fill it 114 | self.refill_buffer() 115 | 116 | # read and consume, refill automatically if the buffer becomes empty, if 117 | # there is not enough data left, just send less than what was asked 118 | def read(self, amount=None): 119 | 120 | assert self.parent.file_open 121 | 122 | if not(amount is None) and amount <= 0: 123 | raise ValueError( 124 | 'The amount to read in h52np.read must be strictly positive') 125 | 126 | if amount is None: 127 | amount = self.buf_len - self.buf_ix 128 | 129 | if self.isempty(): 130 | raise StopIteration 131 | 132 | amount_found = 0 133 | data = [] 134 | while amount_found < amount and not(self.isempty()): 135 | needed = amount - amount_found 136 | amount_in_buffer = self.buf_len - self.buf_ix 137 | if amount_in_buffer > needed: # enough data in buffer 138 | next_buf_ix = self.buf_ix + needed 139 | amount_found = amount 140 | else: 141 | # not enough data in buffer (or just enough) 142 | next_buf_ix = self.buf_len 143 | amount_found = amount_found + amount_in_buffer 144 | 145 | # the np.copy is absolutely necessary here to avoid ugly 146 | # side effects... 147 | data.append(np.copy(self.buf[self.buf_ix:next_buf_ix, :])) 148 | self.buf_ix = next_buf_ix 149 | # fill buffer or not, according to refill policy and current buffer 150 | # state 151 | self.refill_buffer() 152 | return np.concatenate(data) 153 | 154 | def refill_buffer(self): 155 | if not(self.dataset_ix == self.n_rows): 156 | # for now one policy is implemented: if less than 157 | # self.minimum_occupied_portion of the full capacity is occupied 158 | # the buffer is refilled 159 | occupied_portion = 1. - float(self.buf_ix) / float(self.buf_len) 160 | if occupied_portion < self.minimum_occupied_portion: 161 | # set useful variables 162 | curr_ix = self.dataset_ix 163 | next_ix = curr_ix + self.buf_ix 164 | next_buf_ix = next_ix - self.n_rows 165 | amount_in_buffer = self.buf_len - self.buf_ix 166 | # take care of not going out of the dataset 167 | next_buf_ix = max(next_buf_ix, 0) 168 | next_ix = min(next_ix, self.n_rows) 169 | # move old data 170 | self.buf[next_buf_ix:next_buf_ix+amount_in_buffer, :] = self.buf[self.buf_ix:,:] 171 | # add new data 172 | self.buf[next_buf_ix+amount_in_buffer:, :] = self.dataset[curr_ix:next_ix,:] 173 | # update indices 174 | self.buf_ix = next_buf_ix 175 | self.dataset_ix = next_ix 176 | 177 | def __iter__(self): 178 | return self 179 | 180 | def __next__(self): 181 | return self.read() 182 | 183 | # true only if the input file has been totally read and the buffer is empty 184 | def isempty(self): 185 | assert self.parent.file_open 186 | return self.dataset_ix == self.n_rows and self.buf_ix == self.buf_len 187 | 188 | def buffer_empty(self): 189 | assert self.parent.file_open 190 | return self.buf_ix == self.buf_len 191 | 192 | def dataset_empty(self): 193 | assert self.parent.file_open 194 | return self.dataset_ix == self.n_rows 195 | 196 | # return the last row currently in the buffer (useful for merge sort...) 
197 | # assuming the data is one-column 198 | def current_tail(self): 199 | assert self.parent.file_open 200 | assert self.n_columns == 1 201 | return self.buf[-1, 0] 202 | 203 | # returns the number of element in the buffer lower or equal to x, 204 | # assuming the data is one-column ordered and sorted 205 | def nb_lower_than(self, x): 206 | assert self.parent.file_open 207 | assert self.n_columns == 1 208 | return bisect.bisect_right(self.buf[self.buf_ix:, :], x) 209 | 210 | 211 | class H5dataset2NPbuffer(H52NPbuffer): 212 | """Augmentation of the H%2NPbuffer, proposing to use a subdataset 213 | selected by index""" 214 | def __init__(self, parent, group, dataset, buf_size, 215 | minimum_occupied_portion, indexes=None): 216 | assert parent.file_open 217 | 218 | # super(H5dataset2NPbuffer, self).__init__( 219 | # parent, group, dataset, buf_size, 220 | # minimum_occupied_portion) 221 | 222 | # get info from dataset 223 | # fail if dataset do not exist 224 | if not(group + '/' + dataset in parent.file): 225 | raise IOError('Dataset %s does not exists in file %s!' % 226 | (dataset, parent.filename)) 227 | dset = parent.file[group][dataset] 228 | self.n_columns = dset.shape[1] 229 | self.type = dset.dtype 230 | self.dataset = dset 231 | self.dataset_ix = 0 232 | self.dataset_end = self.dataset.shape[0] 233 | if indexes is not None: 234 | assert len(indexes) == 2 235 | self.dataset_ix = indexes[0] 236 | self.dataset_end = indexes[1] 237 | self.n_rows = self.dataset_end # - self.dataset_ix 238 | # could add checks: no more than 2 dims, etc. 239 | 240 | self.parent = parent 241 | self.minimum_occupied_portion = minimum_occupied_portion 242 | 243 | # initialize buffer 244 | row_size = self.n_columns * self.type.itemsize / \ 245 | 1000. # entry size in kilobytes 246 | self.buf_len = int(round(buf_size / row_size)) 247 | # buf_ix represents the number of free rows in the buffer. Here the 248 | # buffer is empty 249 | self.buf_ix = self.buf_len 250 | self.buf = np.zeros((self.buf_len, self.n_columns), self.type) 251 | # fill it 252 | self.refill_buffer() 253 | 254 | # read and consume, refill automatically if the buffer becomes empty, if 255 | # there is not enough data left, just send less than what was asked 256 | def read(self, amount=None): 257 | return super(H5dataset2NPbuffer, self).read(amount) 258 | 259 | def refill_buffer(self): 260 | if not(self.dataset_ix == self.dataset_end): 261 | # for now one policy is implemented: if less than 262 | # self.minimum_occupied_portion of the full capacity is occupied 263 | # the buffer is refilled 264 | occupied_portion = 1. 
- float(self.buf_ix) / float(self.buf_len) 265 | if occupied_portion < self.minimum_occupied_portion: 266 | # set useful variables 267 | curr_ix = self.dataset_ix 268 | next_ix = curr_ix + self.buf_ix 269 | next_buf_ix = next_ix - self.n_rows 270 | amount_in_buffer = self.buf_len - self.buf_ix 271 | # take care of not going out of the dataset 272 | next_buf_ix = max(next_buf_ix, 0) 273 | next_ix = min(next_ix, self.dataset_end) 274 | # move old data 275 | self.buf[next_buf_ix:next_buf_ix+amount_in_buffer, :] = self.buf[self.buf_ix:,:] 276 | # add new data 277 | self.buf[next_buf_ix+amount_in_buffer:, :] = self.dataset[curr_ix:next_ix,:] 278 | # update indices 279 | self.buf_ix = next_buf_ix 280 | self.dataset_ix = next_ix 281 | -------------------------------------------------------------------------------- /ABXpy/sampling/sampler.py: -------------------------------------------------------------------------------- 1 | """The sampler class implementing incremental sampling without replacement. 2 | 3 | Incremental meaning that you don't have to draw the whole sample at 4 | once, instead at any given time you can get a piece of the sample of a 5 | size you specify. This is useful for very large sample sizes. 6 | 7 | """ 8 | 9 | import numpy as np 10 | import math 11 | 12 | 13 | class IncrementalSampler(object): 14 | """Class for sampling without replacement in an incremental fashion 15 | 16 | Toy example of usage: 17 | 18 | sampler = IncrementalSampler(10**4, 10**4, step=100, \ 19 | relative_indexing=False) 20 | complete_sample = np.concatenate([sample for sample in sampler]) 21 | assert all(complete_sample==range(10**4)) 22 | 23 | More realistic example of usage: sampling without replacement 1 24 | million items from a total of 1 trillion items, considering 100 25 | millions items at a time 26 | 27 | sampler = IncrementalSampler(10**12, 10**6, step=10**8, \ 28 | relative_indexing=False) 29 | complete_sample = np.concatenate([sample for sample in sampler]) 30 | 31 | """ 32 | # sampling K sample in a a population of size N 33 | # both K and N can be very large 34 | def __init__(self, N, K, step=None, relative_indexing=True, 35 | dtype=np.int64): 36 | assert K <= N 37 | self.N = N # remaining items to sample from 38 | self.K = K # remaining items to be sampled 39 | self.initial_N = N 40 | self.relative_indexing = relative_indexing 41 | self.type = dtype # the type of the elements of the sample 42 | # step used when iterating over the sampler 43 | if step is None: 44 | # 10**4 samples by iteration on average 45 | self.step = 10 ** 4 * N // K 46 | else: 47 | self.step = step 48 | 49 | # method for implementing the iterable pattern 50 | def __iter__(self): 51 | return self 52 | 53 | # method for implementing the iterable pattern 54 | def next(self): 55 | if self.N == 0: 56 | raise StopIteration 57 | return self.sample(self.step) 58 | 59 | def sample(self, n, dtype=np.int64): 60 | """Fast implementation of the sampling function 61 | 62 | Get all samples from the next n items in a way that avoid rejection 63 | sampling with too large samples, more precisely samples whose expected 64 | number of sampled items is larger than 10**5. 
65 | 66 | Parameters 67 | ---------- 68 | n : int 69 | the size of the chunk 70 | 71 | Returns 72 | ------- 73 | sample : numpy.array 74 | the indices to keep given relative to the current position 75 | in the sample or absolutely, depending on the value of 76 | relative_indexing specified when initialising the sampler 77 | (default value is True) 78 | 79 | """ 80 | self.type = dtype 81 | position = self.initial_N - self.N 82 | if n > self.N: 83 | n = self.N 84 | 85 | # expected number of sampled items 86 | expected_k = n * self.K / np.float(self.N) 87 | if expected_k > 10 ** 5: 88 | sample = [] 89 | chunk_size = int(np.floor(10 ** 5 * self.N / np.float(self.K))) 90 | i = 0 91 | while n > 0: 92 | amount = min(chunk_size, n) 93 | sample.append(self.simple_sample(amount) + i * chunk_size) 94 | n = n - amount 95 | i += 1 96 | sample = np.concatenate(sample) 97 | else: 98 | sample = self.simple_sample(n) 99 | 100 | if not self.relative_indexing: 101 | sample = sample + position 102 | 103 | return sample 104 | 105 | def simple_sample(self, n): 106 | """get all samples from the next n items in a naive fashion 107 | 108 | Parameters 109 | ---------- 110 | n : int 111 | the size of the chunk 112 | Returns 113 | ------- 114 | sample : numpy.array 115 | the indices to be kept relative to the current position 116 | in the sample 117 | """ 118 | k = hypergeometric_sample(self.N, self.K, n) # get the sample size 119 | sample = sample_without_replacement(k, n, self.type) 120 | self.N = self.N - n 121 | self.K = self.K - k 122 | return sample 123 | 124 | 125 | # function np.random.hypergeometric is buggy so I did my own 126 | # implementation... (error, at least, line 784 in computation of 127 | # variance: sample used instead of m, but this can't be all of it ?) 128 | # following algo HRUA by Ernst Stadlober as implemented in numpy 129 | # (https://github.com/numpy/numpy/blob/master/numpy/random/mtrand/ 130 | # distributions.c and see original ref in zotero) 131 | # this is 100 to 200 times slower than np.random.hypergeometric, but 132 | # it works reliably could be optimized a lot if needed (for small 133 | # samples in particular but also generally) 134 | # seems at worse to require comparable execution time when compared to the 135 | # actual rejection sampling, so probably not going to be so bad all in all 136 | def hypergeometric_sample(N, K, n): 137 | """This function return the number of elements to sample from the next n 138 | items. 139 | """ 140 | # handling edge cases 141 | if N == 0 or N == 1: 142 | k = K 143 | else: 144 | # using symmetries to speed up computations 145 | # if the probability of failure is smaller than the probability of 146 | # success, draw the failure count 147 | K_eff = min(K, N - K) 148 | # if the amount of items to sample from is larger than the amount of 149 | # items that will remain, draw from the items that will remain 150 | n_eff = min(n, N - n) 151 | N_float = np.float64(N) # useful to avoid unexpected roundings 152 | 153 | average = n_eff * (K_eff / N_float) 154 | mode = np.floor((n_eff + 1) * ((K_eff + 1) / (N_float + 2))) 155 | variance = average * ((N - K_eff) / N_float) * \ 156 | ((N - n_eff) / (N_float - 1)) 157 | c1 = 2 * np.sqrt(2 / np.e) 158 | c2 = 3 - 2 * np.sqrt(3 / np.e) 159 | a = average + 0.5 160 | b = c1 * np.sqrt(variance + 0.5) + c2 161 | p_mode = (math.lgamma(mode + 1) + math.lgamma(K_eff - mode + 1) + 162 | math.lgamma(n_eff - mode + 1) + 163 | math.lgamma(N - K_eff - n_eff + mode + 1)) 164 | # 16 for 16-decimal-digit precision in c1 and c2 (?) 
165 | upper_bound = min( 166 | min(n_eff, K_eff) + 1, np.floor(a + 16 * np.sqrt(variance + 0.5))) 167 | 168 | while True: 169 | U = np.random.rand() 170 | V = np.random.rand() 171 | k = np.int64(np.floor(a + b * (V - 0.5) / U)) 172 | if k < 0 or k >= upper_bound: 173 | continue 174 | else: 175 | p_k = math.lgamma(k + 1) + math.lgamma(K_eff - k + 1) + \ 176 | math.lgamma(n_eff - k + 1) + \ 177 | math.lgamma(N - K_eff - n_eff + k + 1) 178 | d = p_mode - p_k 179 | if U * (4 - U) - 3 <= d: 180 | break 181 | if U * (U - d) >= 1: 182 | continue 183 | if 2 * np.log(U) <= d: 184 | break 185 | 186 | # retrieving original variables by symmetry 187 | if K_eff < K: 188 | k = n_eff - k 189 | if n_eff < n: 190 | k = K - k 191 | 192 | return k 193 | 194 | 195 | # returns uniform samples in [0, N-1] without replacement the values 196 | # 0.6 and 100 are based on empirical tests of the functions and would 197 | # need to be changed if the functions are changed 198 | def sample_without_replacement(n, N, dtype=np.int64): 199 | """Returns uniform samples in [0, N-1] without replacement. It will use 200 | Knuth sampling or rejection sampling depending on the parameters n and N. 201 | 202 | .. note:: 203 | 204 | the values 0.6 and 100 are based on empirical tests of the 205 | functions and would need to be changed if the functions are 206 | changed 207 | 208 | """ 209 | if N > 100 and n / float(N) < 0.6: 210 | sample = rejection_sampling(n, N, dtype) 211 | else: 212 | sample = Knuth_sampling(n, N, dtype) 213 | return sample 214 | 215 | 216 | # this one would benefit a lot from being cythonized, efficient if n 217 | # close to N (np.random.choice with replace=False is cythonized and 218 | # similar in spirit but not better because it shuffles the whole array 219 | # of size N which is wasteful; once cythonized Knuth_sampling should 220 | # be superior to it in all situation) 221 | def Knuth_sampling(n, N, dtype=np.int64): 222 | """This is the usual sampling function when n is comparable to N""" 223 | n = int(n) 224 | 225 | t = 0 # total input records dealt with 226 | m = 0 # number of items selected so far 227 | sample = np.zeros(shape=n, dtype=dtype) 228 | while m < n: 229 | u = np.random.rand() 230 | if (N - t) * u < n - m: 231 | sample[m] = t 232 | m = m + 1 233 | t = t + 1 234 | return sample 235 | 236 | 237 | # maybe use array for the first iteration then use python native sets 238 | # for faster set operations ? 239 | def rejection_sampling(n, N, dtype=np.int64): 240 | """Using rejection sampling to keep a good performance if n << N""" 241 | remaining = n 242 | sample = np.array([], dtype=dtype) 243 | while remaining > 0: 244 | new_sample = np.random.randint(0, int(N), int(remaining)).astype(dtype) 245 | # keeping only unique element: 246 | sample = np.union1d(sample, np.unique(new_sample)) 247 | remaining = n - sample.shape[0] 248 | return sample 249 | 250 | 251 | # Profiling hypergeometric sampling + sampling without replacement together: 252 | 253 | # ChunkSize 254 | # 10**2: 255 | # hyper: 46s 256 | # sample: 32s 257 | # 10**3: 258 | # hyper : 4.5s 259 | # sample 6s 260 | # 10**4: 261 | # hyper 0.5s 262 | # sample 4.5s 263 | # 10**5: 264 | # hyper 0.05s 265 | # sample 4s 266 | # 10**6: 267 | # hyper 0.007s 268 | # sample 5.7s 269 | # 10**7: 270 | # hyper 0.001s 271 | # sample 10.33s 272 | # + memory increase with chunk size 273 | 274 | # Should aim at having samples with around 100 000 elements. 275 | # This means sampling in 10**5 * sampled_proportion chunks. 
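# The snippet below is an illustrative sketch (not part of the original module)
# of the dispatch rule implemented in sample_without_replacement above: with the
# empirical 0.6 and 100 thresholds, dense draws go through Knuth_sampling while
# sparse draws from a large population go through rejection_sampling; both
# return exactly n distinct indices in [0, N-1].
#
# dense = sample_without_replacement(8000, 10**4)   # n/N = 0.8   -> Knuth_sampling
# sparse = sample_without_replacement(50, 10**6)    # n/N = 5e-05 -> rejection_sampling
# assert len(np.unique(dense)) == 8000
# assert len(np.unique(sparse)) == 50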
276 | 277 | # profiling code: 278 | 279 | # import time 280 | # tt=[] 281 | # ra = range(8, 9) 282 | # for block_size in ra: 283 | # t = time.clock() 284 | # progress = 0 285 | # b = 10**block_size 286 | # for i in range(10**12//(10**3*b)): 287 | # r = s.sample(10**3*b) 288 | # progress = progress+100*(len(r)/10.**9) 289 | # print(progress) 290 | # if progress > 3: 291 | # break 292 | # tt.append(time.clock()-t) 293 | # for e, b in zip(tt, ra): 294 | # print(b) 295 | # print(e) 296 | 297 | # ## Profiling rejection sampling and Knuth sampling ### 298 | # could create an automatic test for finding the turning point and offset 299 | # between Knuth and rejection 300 | 301 | # manual results: 302 | # N 100 303 | # n 304 | # 1 R:60mu, K:30mu 305 | # 10 R:83mu, K:54mu 306 | # 100 R:7780mu, K:78mu 307 | # N < 100 always Knuth 308 | # 309 | # N 1000 310 | # n 311 | # 1 R:60mu, K:248mu 312 | # 10 R:65mu, K:450mu 313 | # 100 R:150mu, K:523mu 314 | # 1000 R:8610mu, K:785mu 315 | # turning point: n/N between 0.5 and 0.75 316 | # 317 | # N 10**6 318 | # 10**6 R:???, K:791ms 319 | # 10**4 R:1ms, K:562ms 320 | # 10**5 R:20ms, K:531ms 321 | # turning point: n/N between 0.5 and 0.75 322 | # 323 | # N 10**9 324 | # 10**6 R:174ms 325 | # 10**7 R:2.7s 326 | # 327 | # N 10**18 328 | # 1 R: 62mu 329 | # 10 R: 62mu 330 | # 10**3 R: 148mu 331 | # 10**6 R: 131ms 332 | --------------------------------------------------------------------------------
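Illustrative usage sketches (not part of the repository sources above). The first exercises the buffered HDF5 writer/reader pair from ABXpy/h5tools, assuming that NP2H5, like H52NP, also accepts an already-open h5py.File handle (as the h52np.py module docstring suggests); 'example.h5' and the group/dataset names are placeholders, and flush() is called explicitly rather than relying on the context manager. The second draws a fixed-size random sample incrementally with ABXpy.sampling.sampler.IncrementalSampler, calling its public sample() method directly instead of iterating over the object.

import h5py
import numpy as np
from ABXpy.h5tools.np2h5 import NP2H5
from ABXpy.h5tools.h52np import H52NP
from ABXpy.sampling.sampler import IncrementalSampler

# buffered write, then buffered read, of a small integer matrix
data = np.arange(20, dtype=np.int64).reshape(10, 2)
fh = h5py.File('example.h5', 'w')  # create the file ourselves, hand the open handle to NP2H5
with NP2H5(fh) as writer:
    # fixed-size dataset of 10 rows and 2 columns, written in two chunks
    out = writer.add_dataset('demo', 'matrix', n_rows=10, n_columns=2, item_type=np.int64)
    out.write(data[:6])
    out.write(data[6:])
    out.flush()  # push whatever is still sitting in the write buffer
fh.close()
with H52NP('example.h5') as reader:
    buf = reader.add_dataset('demo', 'matrix')
    assert np.array_equal(buf.read(10), data)

# incremental sampling without replacement: 1000 absolute indices out of 10**6
sampler = IncrementalSampler(10 ** 6, 10 ** 3, relative_indexing=False)
chunks = []
while sampler.N > 0:  # sampler.N is the number of items left to consider
    chunks.append(sampler.sample(10 ** 5))
sample = np.concatenate(chunks)
assert len(sample) == 10 ** 3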