├── chainsaw ├── _ext │ ├── __init__.py │ └── variational_estimators │ │ ├── tests │ │ ├── __init__.py │ │ └── benchmark_moments.py │ │ ├── covar_c │ │ ├── __init__.py │ │ └── _covartools.h │ │ ├── __init__.py │ │ └── lagged_correlation.py ├── base │ ├── __init__.py │ ├── reporter.py │ ├── estimator.py │ ├── loggable.py │ └── model.py ├── data │ ├── _base │ │ ├── __init__.py │ │ └── random_accessible.py │ ├── md │ │ ├── featurization │ │ │ ├── __init__.py │ │ │ └── _base.py │ │ └── __init__.py │ ├── util │ │ ├── __init__.py │ │ └── fileformat_registry.py │ ├── __init__.py │ └── numpy_filereader.py ├── tests │ ├── data │ │ ├── bpti_mini.dcd │ │ ├── bpti_mini.h5 │ │ ├── bpti_mini.lh5 │ │ ├── bpti_mini.nc │ │ ├── bpti_mini.trr │ │ ├── bpti_mini.xtc │ │ ├── bpti_001-033.xtc │ │ ├── bpti_034-066.xtc │ │ ├── bpti_067-100.xtc │ │ ├── bpti_mini.binpos │ │ ├── bpti_mini.netcdf │ │ ├── opsin_Ca_1_frame.pdb.gz │ │ ├── opsin_aa_1_frame.pdb.gz │ │ ├── test.pdb │ │ └── bpti_ca.pdb │ ├── __init__.py │ ├── test_format_registry.py │ ├── test_acf.py │ ├── test_uniform_time.py │ ├── test_mini_batch_kmeans.py │ ├── util.py │ ├── test_cluster_samples.py │ ├── test_api_load.py │ ├── test_cache.py │ ├── test_featurereader_and_tica.py │ ├── test_stride.py │ ├── test_regspace.py │ ├── test_discretizer.py │ ├── test_coordinates_iterator.py │ ├── test_featurereader_and_tica_projection.py │ ├── test_cluster.py │ └── test_pca.py ├── util │ ├── __init__.py │ ├── contexts.py │ ├── exceptions.py │ ├── change_notification.py │ ├── files.py │ ├── indices.py │ ├── stat.py │ ├── log.py │ ├── units.py │ ├── reflection.py │ └── linalg.py ├── _resources │ ├── logging.yml │ └── chainsaw.cfg ├── transform │ └── __init__.py ├── clustering │ ├── __init__.py │ ├── include │ │ └── clustering.h │ ├── assign.py │ ├── uniform_time.py │ └── regspace.py ├── __init__.py └── acf.py ├── .gitattributes ├── README.rst ├── devtools ├── conda-recipe │ ├── build.sh │ ├── build.sh.orig │ ├── bld.bat │ ├── meta.yaml │ └── 
run_test.py └── ci │ ├── appveyor │ ├── after_success.bat │ ├── deploy.ps1 │ ├── runTestsuite.ps1 │ ├── run_with_env.cmd │ └── transform_xunit_to_appveyor.xsl │ ├── travis │ ├── install_miniconda.sh │ ├── make_docs.sh │ ├── after_success.sh │ └── dev_pkgs_del_old.py │ └── jenkins │ └── update_versions_json.py ├── .gitignore ├── MANIFEST.in ├── setup.cfg ├── .travis.yml └── setup_util.py /chainsaw/_ext/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chainsaw/base/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chainsaw/data/_base/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | chainsaw/_version.py export-subst 2 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | chainsaw 2 | ======== 3 | 4 | desc 5 | -------------------------------------------------------------------------------- /chainsaw/_ext/variational_estimators/tests/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'noe' 2 | -------------------------------------------------------------------------------- /devtools/conda-recipe/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | $PYTHON setup.py install 3 | -------------------------------------------------------------------------------- /chainsaw/_ext/variational_estimators/covar_c/__init__.py: 
-------------------------------------------------------------------------------- 1 | __author__ = 'noe' 2 | -------------------------------------------------------------------------------- /chainsaw/data/md/featurization/__init__.py: -------------------------------------------------------------------------------- 1 | from .featurizer import MDFeaturizer, CustomFeature 2 | -------------------------------------------------------------------------------- /chainsaw/tests/data/bpti_mini.dcd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markovmodel/coordinates/master/chainsaw/tests/data/bpti_mini.dcd -------------------------------------------------------------------------------- /chainsaw/tests/data/bpti_mini.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markovmodel/coordinates/master/chainsaw/tests/data/bpti_mini.h5 -------------------------------------------------------------------------------- /chainsaw/tests/data/bpti_mini.lh5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markovmodel/coordinates/master/chainsaw/tests/data/bpti_mini.lh5 -------------------------------------------------------------------------------- /chainsaw/tests/data/bpti_mini.nc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markovmodel/coordinates/master/chainsaw/tests/data/bpti_mini.nc -------------------------------------------------------------------------------- /chainsaw/tests/data/bpti_mini.trr: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markovmodel/coordinates/master/chainsaw/tests/data/bpti_mini.trr -------------------------------------------------------------------------------- /chainsaw/tests/data/bpti_mini.xtc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/markovmodel/coordinates/master/chainsaw/tests/data/bpti_mini.xtc -------------------------------------------------------------------------------- /chainsaw/tests/data/bpti_001-033.xtc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markovmodel/coordinates/master/chainsaw/tests/data/bpti_001-033.xtc -------------------------------------------------------------------------------- /chainsaw/tests/data/bpti_034-066.xtc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markovmodel/coordinates/master/chainsaw/tests/data/bpti_034-066.xtc -------------------------------------------------------------------------------- /chainsaw/tests/data/bpti_067-100.xtc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markovmodel/coordinates/master/chainsaw/tests/data/bpti_067-100.xtc -------------------------------------------------------------------------------- /chainsaw/tests/data/bpti_mini.binpos: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markovmodel/coordinates/master/chainsaw/tests/data/bpti_mini.binpos -------------------------------------------------------------------------------- /chainsaw/tests/data/bpti_mini.netcdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markovmodel/coordinates/master/chainsaw/tests/data/bpti_mini.netcdf -------------------------------------------------------------------------------- /chainsaw/tests/data/opsin_Ca_1_frame.pdb.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/markovmodel/coordinates/master/chainsaw/tests/data/opsin_Ca_1_frame.pdb.gz -------------------------------------------------------------------------------- /chainsaw/tests/data/opsin_aa_1_frame.pdb.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markovmodel/coordinates/master/chainsaw/tests/data/opsin_aa_1_frame.pdb.gz -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | build 3 | 4 | *.pyc 5 | *.so 6 | *.egg-info/ 7 | *.orig 8 | 9 | */_ext/variational_estimators/covar_c/covartools.c 10 | 11 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include versioneer.py 2 | include chainsaw/_version.py 3 | 4 | recursive-include chainsaw/_resources * 5 | recursive-include chainsaw/tests/data * 6 | -------------------------------------------------------------------------------- /chainsaw/tests/data/test.pdb: -------------------------------------------------------------------------------- 1 | ATOM 3 CH3 ACE A 1 28.490 31.600 33.379 0.00 1.00 2 | ATOM 7 C ACE A 2 27.760 30.640 34.299 0.00 1.00 3 | ATOM 8 C ACE A 2 27.760 30.640 34.299 0.00 1.00 4 | 5 | -------------------------------------------------------------------------------- /chainsaw/data/md/__init__.py: -------------------------------------------------------------------------------- 1 | import pkgutil 2 | 3 | loader = pkgutil.find_loader('mdtraj') 4 | 5 | if loader is not None: 6 | from .feature_reader import FeatureReader 7 | from .featurization.featurizer import MDFeaturizer, CustomFeature 8 | -------------------------------------------------------------------------------- /chainsaw/_ext/variational_estimators/__init__.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .moments import moments_XX, moments_XXXY, moments_block 4 | from .moments import covar, covars 5 | from .running_moments import RunningCovar, running_covar 6 | -------------------------------------------------------------------------------- /devtools/conda-recipe/build.sh.orig: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | $PYTHON setup.py install 3 | <<<<<<< HEAD 4 | version=$($PYTHON -c "from __future__ import print_function; import pyemma; print(pyemma.__version__)") 5 | export PYEMMA_VERSION=$version 6 | ======= 7 | >>>>>>> devel 8 | -------------------------------------------------------------------------------- /devtools/ci/appveyor/after_success.bat: -------------------------------------------------------------------------------- 1 | conda install --yes -q anaconda-client jinja2 2 | cd %PYTHON_MINICONDA%\conda-bld 3 | dir /s /b %PACKAGENAME%-dev-*.tar.bz2 > files.txt 4 | for /F %%filename in (files.txt) do ( 5 | echo "uploading file %%~filename" 6 | anaconda -t %BINSTAR_TOKEN% upload --force -u %ORGNAME% -p %PACKAGENAME%-dev %%~filename 7 | ) 8 | -------------------------------------------------------------------------------- /devtools/conda-recipe/bld.bat: -------------------------------------------------------------------------------- 1 | if not defined APPVEYOR ( 2 | echo not on appveyor 3 | "%PYTHON%" setup.py install 4 | ) else ( 5 | echo on appveyor 6 | cmd /E:ON /V:ON /C %APPVEYOR_BUILD_FOLDER%\devtools\ci\appveyor\run_with_env.cmd "%PYTHON%" setup.py install 7 | ) 8 | set build_status=%ERRORLEVEL% 9 | if %build_status% == 1 exit 1 10 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | 2 | # See the docstring in versioneer.py for 
instructions. Note that you must 3 | # re-run 'versioneer.py setup' after changing this section, and commit the 4 | # resulting files. 5 | 6 | [versioneer] 7 | VCS = git 8 | style = pep440 9 | versionfile_source = chainsaw/_version.py 10 | versionfile_build = chainsaw/_version.py 11 | tag_prefix = v 12 | #parentdir_prefix = 13 | 14 | -------------------------------------------------------------------------------- /devtools/ci/appveyor/deploy.ps1: -------------------------------------------------------------------------------- 1 | function deploy() { 2 | # install tools 3 | pip install wheel twine 4 | 5 | # create wheel and win installer 6 | python setup.py bdist_wheel bdist_wininst 7 | 8 | # upload to pypi with twine 9 | twine upload -i $env:myuser -p $env:mypass dist/* 10 | } 11 | 12 | new_tag = ($env:APPVEYOR_REPO_TAG -eq true) 13 | new_tag = true # temporarily enable for all commits 14 | 15 | if (new_tag) { 16 | deploy 17 | } -------------------------------------------------------------------------------- /chainsaw/base/reporter.py: -------------------------------------------------------------------------------- 1 | from progress_reporter import ProgressReporter as _impl 2 | 3 | 4 | class ProgressReporter(_impl): 5 | @_impl.show_progress.getter 6 | def show_progress(self): 7 | """ whether to show the progress of heavy calculations on this object. 
""" 8 | if not hasattr(self, "_show_progress"): 9 | from chainsaw import config 10 | val = config.show_progress_bars 11 | self._show_progress = val 12 | return self._show_progress 13 | -------------------------------------------------------------------------------- /devtools/ci/travis/install_miniconda.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # make TARGET overrideable with env 4 | : ${TARGET:=$HOME/miniconda} 5 | 6 | function install_miniconda { 7 | if [ -d $TARGET ]; then echo "file exists"; return; fi 8 | echo "installing miniconda to $TARGET" 9 | if [[ "$TRAVIS_OS_NAME" == "linux" ]]; then 10 | platform="Linux" 11 | elif [[ "$TRAVIS_OS_NAME" == "osx" ]]; then 12 | platform="MacOSX" 13 | fi 14 | wget http://repo.continuum.io/miniconda/Miniconda3-latest-$platform-x86_64.sh -O mc.sh -o /dev/null 15 | bash mc.sh -b -f -p $TARGET 16 | } 17 | 18 | install_miniconda 19 | export PATH=$TARGET/bin:$PATH 20 | -------------------------------------------------------------------------------- /chainsaw/util/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | # This file is part of PyEMMA. 3 | # 4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER) 5 | # 6 | # PyEMMA is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU Lesser General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU Lesser General Public License 17 | # along with this program. If not, see . 
18 | 19 | -------------------------------------------------------------------------------- /chainsaw/data/util/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | # This file is part of PyEMMA. 3 | # 4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER) 5 | # 6 | # PyEMMA is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU Lesser General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU Lesser General Public License 17 | # along with this program. If not, see . 18 | 19 | -------------------------------------------------------------------------------- /chainsaw/tests/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | # This file is part of PyEMMA. 3 | # 4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER) 5 | # 6 | # PyEMMA is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU Lesser General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 
15 | # 16 | # You should have received a copy of the GNU Lesser General Public License 17 | # along with this program. If not, see . 18 | 19 | from __future__ import absolute_import 20 | -------------------------------------------------------------------------------- /chainsaw/_ext/variational_estimators/covar_c/_covartools.h: -------------------------------------------------------------------------------- 1 | void _subtract_row_double(double* X, double* row, int M, int N); 2 | void _subtract_row_float(double* X, double* row, int M, int N); 3 | void _subtract_row_double_copy(double* X0, double* X, double* row, int M, int N); 4 | int* _bool_to_list(int* b, int N, int nnz); 5 | void _variable_cols_char(int* cols, char* X, int M, int N, int min_constant); 6 | void _variable_cols_int(int* cols, int* X, int M, int N, int min_constant); 7 | void _variable_cols_long(int* cols, long* X, int M, int N, int min_constant); 8 | void _variable_cols_float(int* cols, float* X, int M, int N, int min_constant); 9 | void _variable_cols_double(int* cols, double* X, int M, int N, int min_constant); 10 | void _variable_cols_float_approx(int* cols, float* X, int M, int N, float tol, int min_constant); 11 | void _variable_cols_double_approx(int* cols, double* X, int M, int N, double tol, int min_constant); 12 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: c 2 | git: 3 | submodules: false 4 | os: 5 | - osx 6 | - linux 7 | 8 | sudo: false 9 | 10 | env: 11 | global: 12 | - PATH=$HOME/miniconda/bin:$PATH 13 | - common_py_deps="jinja2 conda-build=1.21.3" 14 | - PACKAGENAME=chainsaw 15 | - ORGNAME=omnia 16 | - DEV_BUILD_N_KEEP=10 17 | matrix: 18 | - CONDA_PY=2.7 CONDA_NPY=1.11 19 | - CONDA_PY=3.4 CONDA_NPY=1.10 20 | - CONDA_PY=3.5 CONDA_NPY=1.11 21 | 22 | before_install: 23 | - devtools/ci/travis/install_miniconda.sh 24 | - conda config --set 
always_yes true 25 | - conda config --add channels omnia 26 | - conda install -q $common_py_deps 27 | 28 | script: 29 | - conda build -q devtools/conda-recipe 30 | 31 | after_success: 32 | # coverage report: needs .coverage file generated by testsuite and git src 33 | - pip install coveralls 34 | - coveralls 35 | #- if [ "$TRAVIS_SECURE_ENV_VARS" == true ]; then source devtools/ci/travis/after_success.sh; fi 36 | 37 | -------------------------------------------------------------------------------- /devtools/ci/travis/make_docs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | function install_deps { 3 | wget https://github.com/jgm/pandoc/releases/download/1.13.2/pandoc-1.13.2-1-amd64.deb \ 4 | -O pandoc.deb 5 | dpkg -x pandoc.deb $HOME 6 | 7 | export PATH=$PATH:$HOME/usr/bin 8 | # try to execute pandoc 9 | pandoc --version 10 | 11 | conda install -q --yes $doc_deps 12 | pip install -r requirements-build-doc.txt wheel 13 | } 14 | 15 | function build_doc { 16 | pushd doc; make ipython-rst html 17 | # workaround for docs dir => move doc to build/docs afterwards 18 | # travis (currently )expects docs in build/docs (should contain index.html?) 
19 | mv build/html ../build/docs 20 | popd 21 | } 22 | 23 | function deploy_doc { 24 | echo "[distutils] 25 | index-servers = pypi 26 | 27 | [pypi] 28 | username:marscher 29 | password:${pypi_pass}" > ~/.pypirc 30 | 31 | python setup.py upload_docs 32 | } 33 | 34 | # build docs only for python 2.7 and for normal commits (not pull requests) 35 | if [[ $TRAVIS_PYTHON_VERSION = "2.7" ]] && [[ "${TRAVIS_PULL_REQUEST}" = "false" ]]; then 36 | install_deps 37 | build_doc 38 | deploy_doc 39 | fi 40 | -------------------------------------------------------------------------------- /chainsaw/_resources/logging.yml: -------------------------------------------------------------------------------- 1 | # Chainsaw's default logging settings 2 | # If you want to enable file logging, uncomment the rotating_files handler below and add it to the logger's handlers list 3 | # 4 | 5 | # do not disable other loggers by default. 6 | disable_existing_loggers: False 7 | 8 | # please do not change version, it is an internal variable used by Python. 9 | version: 1 10 | 11 | formatters: 12 | simpleFormater: 13 | format: '%(asctime)s %(name)-12s %(levelname)-8s %(message)s' 14 | datefmt: '%d-%m-%y %H:%M:%S' 15 | 16 | handlers: 17 | # log to stdout 18 | console: 19 | class: logging.StreamHandler 20 | formatter: simpleFormater 21 | stream: ext://sys.stdout 22 | # example for rotating log files, disabled by default 23 | #rotating_files: 24 | # class: logging.handlers.RotatingFileHandler 25 | # formatter: simpleFormater 26 | # filename: chainsaw.log 27 | # maxBytes: 1048576 # 1 MB 28 | # backupCount: 3 29 | 30 | loggers: 31 | chainsaw: 32 | level: INFO 33 | # by default no file logging! 
34 | handlers: [console] #, rotating_files] 35 | 36 | -------------------------------------------------------------------------------- /chainsaw/_resources/chainsaw.cfg: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # Chainsaw configuration file 3 | # 4 | # notes: 5 | # - comments are not allowed in line, since they would be appended to the value! 6 | ################################################################################ 7 | 8 | [chainsaw] 9 | # configuration notice shown? 10 | show_config_notification = False 11 | 12 | # Source to logging configuration file (YAML). 13 | # Special value: DEFAULT (use default config). 14 | # If this is set to a filename, it will be read to configure logging. If it is a 15 | # relative path, it is assumed to be located next to where you start your interpreter. 16 | logging_config = DEFAULT 17 | 18 | # show or hide progress bars globally? 19 | show_progress_bars = True 20 | 21 | # useful for trajectory formats, for which one has to read the whole file to get len 22 | # eg. XTC format. 
23 | use_trajectory_lengths_cache = True 24 | # maximum entries in database 25 | traj_info_max_entries = 50000 26 | # max size in MB 27 | traj_info_max_size = 500 28 | 29 | # Cache directory, defaults to the operating systems temporary directory, if set to None 30 | cache_dir = None 31 | -------------------------------------------------------------------------------- /devtools/ci/appveyor/runTestsuite.ps1: -------------------------------------------------------------------------------- 1 | function xslt_transform($xml, $xsl, $output) 2 | { 3 | trap [Exception] 4 | { 5 | Write-Host $_.Exception 6 | } 7 | 8 | $xslt = New-Object System.Xml.Xsl.XslCompiledTransform 9 | $xslt.Load($xsl) 10 | $xslt.Transform($xml, $output) 11 | } 12 | 13 | function upload($file) { 14 | trap [Exception] 15 | { 16 | Write-Host $_.Exception 17 | } 18 | 19 | $wc = New-Object 'System.Net.WebClient' 20 | $wc.UploadFile("https://ci.appveyor.com/api/testresults/xunit/$($env:APPVEYOR_JOB_ID)", $file) 21 | } 22 | 23 | function run { 24 | cd $env:APPVEYOR_BUILD_FOLDER 25 | $stylesheet = "devtools/ci/appveyor/transform_xunit_to_appveyor.xsl" 26 | $input = "nosetests.xml" 27 | $output = "transformed.xml" 28 | 29 | nosetests pyemma --all-modules --with-xunit -a '!slow' 30 | $success = $? 31 | Write-Host "result code of nosetests:" $success 32 | 33 | xslt_transform $input $stylesheet $output 34 | 35 | upload $output 36 | Push-AppveyorArtifact $input 37 | Push-AppveyorArtifact $output 38 | 39 | # return exit code of testsuite 40 | if ( -not $success) { 41 | throw "testsuite not successful" 42 | } 43 | } 44 | 45 | run 46 | -------------------------------------------------------------------------------- /devtools/conda-recipe/meta.yaml: -------------------------------------------------------------------------------- 1 | package: 2 | name: chainsaw 3 | # version number: [base tag]+[commits-upstream]_[git_hash] 4 | # eg. 
v2.0+0_g8824162 5 | version: {{ GIT_DESCRIBE_TAG[1:] + '+' +GIT_BUILD_STR}} 6 | source: 7 | path: ../.. 8 | 9 | build: 10 | preserve_egg_dir: True 11 | 12 | requirements: 13 | build: 14 | - python 15 | - setuptools 16 | - cython >=0.20 17 | - mock 18 | - mdtraj # TODO: make it optional? 19 | - funcsigs 20 | - numpy x.x 21 | - h5py 22 | - six 23 | - psutil >=3.1.1 24 | - decorator >=4.0.0 25 | - progress_reporter 26 | - pyyaml 27 | - nomkl 28 | - msmtools # TODO: remove 29 | 30 | run: 31 | - python 32 | - setuptools 33 | - mock 34 | - mdtraj # TODO: make it optional? 35 | - funcsigs 36 | - numpy x.x 37 | - h5py 38 | - six 39 | - psutil >=3.1.1 40 | - decorator >=4.0.0 41 | - progress_reporter 42 | - pyyaml 43 | - nomkl 44 | - msmtools # TODO: remove 45 | 46 | test: 47 | requires: 48 | - nose 49 | - coverage ==4 50 | imports: 51 | - chainsaw 52 | 53 | about: 54 | home: http://emma-project.org 55 | license: GNU Lesser Public License v3+ 56 | summary: "data pipe" 57 | 58 | 59 | -------------------------------------------------------------------------------- /devtools/ci/travis/after_success.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # The pull request number if the current job is a pull request, “false” if it’s not a pull request. 4 | if [[ ! "$TRAVIS_PULL_REQUEST" == "false" ]]; then 5 | echo "This is a pull request. No deployment will be done."; exit 0 6 | fi 7 | 8 | # For builds not triggered by a pull request this is the name of the branch currently being built; 9 | # whereas for builds triggered by a pull request this is the name of the branch targeted by the pull request (in many cases this will be master). 10 | if [ "$TRAVIS_BRANCH" != "devel" ]; then 11 | echo "No deployment on BRANCH='$TRAVIS_BRANCH'"; exit 0 12 | fi 13 | 14 | 15 | # Deploy to binstar 16 | conda install --yes anaconda-client jinja2 17 | pushd . 
18 | cd $HOME/miniconda/conda-bld 19 | FILES=*/${PACKAGENAME}-dev-*.tar.bz2 20 | for filename in $FILES; do 21 | anaconda -t $BINSTAR_TOKEN upload --force -u ${ORGNAME} -p ${PACKAGENAME}-dev ${filename} 22 | done 23 | popd 24 | 25 | # call cleanup only for py35, numpy111 26 | if [[ "$CONDA_PY" == "3.5" && "$CONDA_NPY" == "111" && "$TRAVIS_OS_NAME" == "linux" ]]; then 27 | python devtools/ci/travis/dev_pkgs_del_old.py 28 | else 29 | echo "only executing cleanup script for py35 && npy111 && linux" 30 | fi 31 | 32 | -------------------------------------------------------------------------------- /chainsaw/tests/test_format_registry.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from chainsaw.data.util.fileformat_registry import FileFormatRegistry, FileFormatUnsupported 4 | from chainsaw.data._base.datasource import DataSource 5 | 6 | 7 | class TestFormatRegistry(unittest.TestCase): 8 | def test_multiple_calls_register(self): 9 | with self.assertRaises(RuntimeError) as exc: 10 | @FileFormatRegistry.register(".foo") 11 | @FileFormatRegistry.register(".bar") 12 | class test_src(): 13 | pass 14 | self.assertIn("only once", exc.exception.args[0]) 15 | 16 | def test_correct_reader_by_ext(self): 17 | 18 | @FileFormatRegistry.register(".foo") 19 | class test_src_foo(DataSource): 20 | pass 21 | 22 | @FileFormatRegistry.register(".bar") 23 | class test_src_bar(DataSource): 24 | pass 25 | 26 | self.assertEqual(FileFormatRegistry[".foo"], test_src_foo) 27 | self.assertEqual(FileFormatRegistry[".bar"], test_src_bar) 28 | 29 | def test_not_supported(self): 30 | ext = '.imsurethisisnotvalid' 31 | with self.assertRaises(FileFormatUnsupported) as exc: 32 | FileFormatRegistry[ext] 33 | self.assertIn(ext, exc.exception.args[0]) 34 | self.assertIn('not supported', exc.exception.args[0]) 35 | -------------------------------------------------------------------------------- /chainsaw/transform/__init__.py: 
-------------------------------------------------------------------------------- 1 | 2 | # This file is part of PyEMMA. 3 | # 4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER) 5 | # 6 | # PyEMMA is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU Lesser General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU Lesser General Public License 17 | # along with this program. If not, see . 18 | 19 | 20 | r""" 21 | =============================================================================== 22 | transform - Transformation Utilities (:mod:`chainsaw.transform`) 23 | =============================================================================== 24 | .. currentmodule:: chainsaw.transform 25 | 26 | .. autosummary:: 27 | :toctree: generated/ 28 | 29 | PCA - principal components 30 | TICA - time independent components 31 | """ 32 | 33 | from .pca import * 34 | from .tica import * 35 | -------------------------------------------------------------------------------- /chainsaw/data/md/featurization/_base.py: -------------------------------------------------------------------------------- 1 | # This file is part of PyEMMA. 
2 | # 3 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER) 4 | # 5 | # PyEMMA is free software: you can redistribute it and/or modify 6 | # it under the terms of the GNU Lesser General Public License as published by 7 | # the Free Software Foundation, either version 3 of the License, or 8 | # (at your option) any later version. 9 | # 10 | # This program is distributed in the hope that it will be useful, 11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | # GNU General Public License for more details. 14 | # 15 | # You should have received a copy of the GNU Lesser General Public License 16 | # along with this program. If not, see . 17 | ''' 18 | Created on 15.02.2016 19 | 20 | @author: marscher 21 | ''' 22 | from chainsaw.util.annotators import deprecated 23 | 24 | 25 | class Feature(object): 26 | 27 | @property 28 | def dimension(self): 29 | return self._dim 30 | 31 | @dimension.setter 32 | def dimension(self, val): 33 | assert isinstance(val, int) 34 | self._dim = val 35 | 36 | @deprecated('use transform(traj)') 37 | def map(self, traj): 38 | return self.transform(traj) 39 | 40 | def __eq__(self, other): 41 | return self.__hash__() == other.__hash__() 42 | -------------------------------------------------------------------------------- /devtools/conda-recipe/run_test.py: -------------------------------------------------------------------------------- 1 | 2 | import subprocess 3 | import os 4 | import sys 5 | import shutil 6 | import re 7 | 8 | src_dir = os.getenv('SRC_DIR') 9 | 10 | test_pkg = 'chainsaw' 11 | cover_pkg = test_pkg 12 | 13 | # matplotlib headless backend 14 | with open('matplotlibrc', 'w') as fh: 15 | fh.write('backend: Agg') 16 | 17 | 18 | def coverage_report(): 19 | fn = '.coverage' 20 | assert os.path.exists(fn) 21 | build_dir = os.getenv('TRAVIS_BUILD_DIR') 22 | dest = os.path.join(build_dir, fn) 23 | print( 
class FileFormatUnsupported(Exception):
    """ No available reader is able to handle the given extension."""


class FileFormatRegistry(object):
    """ Registry for trajectory file objects.

    Maps file extensions to the reader class registered for them. The
    mapping lives on the class, so it is shared by every instance.
    """
    _readers = {}

    @classmethod
    def register(cls, *args):
        """ register the given class as Reader class for given extension(s).

        Parameters
        ----------
        *args : str
            extensions the decorated class is able to read.

        Returns
        -------
        decorator : callable
            class decorator which stores the mapping, sets
            ``SUPPORTED_EXTENSIONS`` on the class and returns the class.

        Raises
        ------
        RuntimeError
            (from the decorator) if the class already carries a non-empty
            ``SUPPORTED_EXTENSIONS`` attribute.
        """
        def decorator(f):
            # Compare by value: identity against a tuple literal ("is not ()")
            # depends on interpreter internals and warns on current Python.
            if getattr(f, 'SUPPORTED_EXTENSIONS', ()) != ():
                raise RuntimeError("please call register() only once per class!")
            extensions = tuple(args)
            cls._readers.update({e: f for e in extensions})
            f.SUPPORTED_EXTENSIONS = extensions
            return f
        return decorator

    @staticmethod
    def is_md_format(extension):
        """ Return True if ``extension`` is listed by ``FeatureReader``.

        Returns False when ``chainsaw.data.md`` exposes no ``FeatureReader``.
        """
        import chainsaw.data.md as md
        if hasattr(md, 'FeatureReader'):
            from chainsaw.data.md import FeatureReader
            return extension in FeatureReader.SUPPORTED_EXTENSIONS

        return False

    def supported_extensions(self):
        """ Return the keys (extensions) of the reader mapping. """
        return self._readers.keys()

    def __getitem__(self, item):
        """ Return the reader class registered for extension ``item``.

        Raises
        ------
        FileFormatUnsupported
            if no reader has been registered for ``item``.
        """
        try:
            return self._readers[item]
        except KeyError:
            raise FileFormatUnsupported("Extension {ext} is not supported by any reader.".format(ext=item))

# singleton pattern
FileFormatRegistry = FileFormatRegistry()
18 | 19 | 20 | r""" 21 | =============================================================================== 22 | clustering - Algorithms (:mod:`chainsaw.clustering`) 23 | =============================================================================== 24 | 25 | .. currentmodule: chainsaw.clustering 26 | 27 | .. autosummary:: 28 | :toctree: generated/ 29 | 30 | AssignCenters 31 | KmeansClustering 32 | RegularSpaceClustering 33 | UniformTimeClustering 34 | """ 35 | 36 | from .assign import AssignCenters 37 | from .kmeans import KmeansClustering 38 | from .kmeans import MiniBatchKmeansClustering 39 | from .regspace import RegularSpaceClustering 40 | from .uniform_time import UniformTimeClustering -------------------------------------------------------------------------------- /devtools/ci/travis/dev_pkgs_del_old.py: -------------------------------------------------------------------------------- 1 | """ 2 | Cleanup old development builds on Anaconda.org 3 | 4 | Assumes one has set 4 environment variables: 5 | 6 | 1. BINSTAR_TOKEN: token to authenticate with anaconda.org 7 | 2. DEV_BUILD_N_KEEP: int, how many builds to keep, delete oldest first. 8 | 3. ORGNAME: str, anaconda.org organisation/user 9 | 4. PACKGENAME: str, name of package to clean up 10 | 11 | author: Martin K. 
token = os.environ['BINSTAR_TOKEN']
org = os.environ['ORGNAME']
pkg = os.environ['PACKAGENAME']
n_keep = int(os.getenv('DEV_BUILD_N_KEEP', 10))

b = get_server_api(token=token)
package = b.package(org, pkg)

# Sort releases by version number, newest first: pop() in the loop below
# then always removes the oldest release that is still in the list.
sorted_by_version = sorted(package['releases'],
                           key=lambda rel: parse_version(rel['version']),
                           reverse=True
                           )
to_delete = []
# Never announce a negative number when fewer than n_keep releases exist.
n_remove = max(0, len(sorted_by_version) - n_keep)
print("Currently have {n} versions online. Going to remove {x}.".
      format(n=len(sorted_by_version), x=n_remove))

while len(sorted_by_version) > n_keep:
    to_delete.append(sorted_by_version.pop())


# remove old releases from anaconda.org (oldest first)
for rel in to_delete:
    version = rel['version']
    print("removing {version}".format(version=version))
    b.remove_release(org, pkg, version)
See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU Lesser General Public License 17 | # along with this program. If not, see . 18 | 19 | 20 | 21 | from __future__ import absolute_import 22 | import unittest 23 | import numpy as np 24 | 25 | from chainsaw.acf import acf 26 | from six.moves import range 27 | 28 | 29 | class TestACF(unittest.TestCase): 30 | def test(self): 31 | # generate some data 32 | data = np.random.rand(100, 3) 33 | 34 | testacf = acf(data) 35 | 36 | # direct computation of acf (single trajectory, three observables) 37 | N = data.shape[0] 38 | refacf = np.zeros(data.shape) 39 | meanfree = data - np.mean(data, axis=0) 40 | padded = np.concatenate((meanfree, np.zeros(data.shape)), axis=0) 41 | for tau in range(N): 42 | refacf[tau] = (padded[0:N, :]*padded[tau:N+tau, :]).sum(axis=0)/(N-tau) 43 | refacf /= refacf[0] # normalize 44 | 45 | np.testing.assert_allclose(refacf, testacf) 46 | 47 | if __name__ == "__main__": 48 | unittest.main() -------------------------------------------------------------------------------- /devtools/ci/appveyor/run_with_env.cmd: -------------------------------------------------------------------------------- 1 | :: To build extensions for 64 bit Python 3, we need to configure environment 2 | :: variables to use the MSVC 2010 C++ compilers from GRMSDKX_EN_DVD.iso of: 3 | :: MS Windows SDK for Windows 7 and .NET Framework 4 (SDK v7.1) 4 | :: 5 | :: To build extensions for 64 bit Python 2, we need to configure environment 6 | :: variables to use the MSVC 2008 C++ compilers from GRMSDKX_EN_DVD.iso of: 7 | :: MS Windows SDK for Windows 7 and .NET Framework 3.5 (SDK v7.0) 8 | :: 9 | :: 32 bit builds do not require specific environment configurations. 
10 | :: 11 | :: Note: this script needs to be run with the /E:ON and /V:ON flags for the 12 | :: cmd interpreter, at least for (SDK v7.0) 13 | :: 14 | :: More details at: 15 | :: https://github.com/cython/cython/wiki/64BitCythonExtensionsOnWindows 16 | :: http://stackoverflow.com/a/13751649/163740 17 | :: 18 | :: Author: Olivier Grisel 19 | :: License: CC0 1.0 Universal: http://creativecommons.org/publicdomain/zero/1.0/ 20 | @ECHO OFF 21 | 22 | SET COMMAND_TO_RUN=%* 23 | SET WIN_SDK_ROOT=C:\Program Files\Microsoft SDKs\Windows 24 | 25 | SET MAJOR_PYTHON_VERSION="%CONDA_PY:~0,1%" 26 | IF %MAJOR_PYTHON_VERSION% == "2" ( 27 | SET WINDOWS_SDK_VERSION="v7.0" 28 | ) ELSE IF %MAJOR_PYTHON_VERSION% == "3" ( 29 | SET WINDOWS_SDK_VERSION="v7.1" 30 | ) ELSE ( 31 | ECHO Unsupported Python version: "%MAJOR_PYTHON_VERSION%" 32 | EXIT 1 33 | ) 34 | 35 | IF "%ARCH%"=="64" ( 36 | ECHO Configuring Windows SDK %WINDOWS_SDK_VERSION% for Python %MAJOR_PYTHON_VERSION% on a 64 bit architecture 37 | SET DISTUTILS_USE_SDK=1 38 | SET MSSdk=1 39 | "%WIN_SDK_ROOT%\%WINDOWS_SDK_VERSION%\Setup\WindowsSdkVer.exe" -q -version:%WINDOWS_SDK_VERSION% 40 | "%WIN_SDK_ROOT%\%WINDOWS_SDK_VERSION%\Bin\SetEnv.cmd" /x64 /release 41 | ECHO Executing: %COMMAND_TO_RUN% 42 | call %COMMAND_TO_RUN% || EXIT 1 43 | ) ELSE ( 44 | ECHO Using default MSVC build environment for 32 bit architecture 45 | ECHO Executing: %COMMAND_TO_RUN% 46 | call %COMMAND_TO_RUN% || EXIT 1 47 | ) 48 | -------------------------------------------------------------------------------- /chainsaw/data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | # This file is part of PyEMMA. 
3 | # 4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER) 5 | # 6 | # PyEMMA is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU Lesser General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU Lesser General Public License 17 | # along with this program. If not, see . 18 | 19 | 20 | r""" 21 | =============================================================================== 22 | data - Data and input/output utilities (:mod:`chainsaw.data`) 23 | =============================================================================== 24 | 25 | .. currentmodule: chainsaw.data 26 | 27 | Order parameters 28 | ================ 29 | 30 | .. autosummary:: 31 | :toctree: generated/ 32 | 33 | MDFeaturizer - selects and computes features from MD trajectories 34 | CustomFeature - define arbitrary function to extract features 35 | 36 | Reader 37 | ====== 38 | 39 | .. autosummary:: 40 | :toctree: generated/ 41 | 42 | FeatureReader - reads features via featurizer 43 | NumPyFileReader - reads numpy files 44 | PyCSVReader - reads tabulated ascii files 45 | DataInMemory - used if data is already available in mem 46 | 47 | """ 48 | from chainsaw.data.md.feature_reader import FeatureReader 49 | 50 | from .data_in_memory import DataInMemory 51 | from .numpy_filereader import NumPyFileReader 52 | from .py_csv_reader import PyCSVReader 53 | from .util.reader_utils import create_file_reader 54 | 55 | 56 | 57 | # TODO: if mdtraj avail, import FEatur3REader to register md extensions. 
58 | -------------------------------------------------------------------------------- /chainsaw/tests/test_uniform_time.py: -------------------------------------------------------------------------------- 1 | 2 | # This file is part of PyEMMA. 3 | # 4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER) 5 | # 6 | # PyEMMA is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU Lesser General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU Lesser General Public License 17 | # along with this program. If not, see . 
18 | 19 | 20 | ''' 21 | Created on 09.04.2015 22 | 23 | @author: marscher 24 | ''' 25 | 26 | from __future__ import absolute_import 27 | import unittest 28 | 29 | import numpy as np 30 | 31 | from chainsaw import api 32 | from chainsaw.data.data_in_memory import DataInMemory 33 | 34 | 35 | class TestUniformTimeClustering(unittest.TestCase): 36 | 37 | def test_1d(self): 38 | x = np.random.random(1000) 39 | reader = DataInMemory(x) 40 | 41 | k = 2 42 | c = api.cluster_uniform_time(k=k) 43 | 44 | c.data_producer = reader 45 | c.parametrize() 46 | 47 | def test_2d(self): 48 | x = np.random.random((300, 3)) 49 | reader = DataInMemory(x) 50 | 51 | k = 2 52 | c = api.cluster_uniform_time(k=k) 53 | 54 | c.data_producer = reader 55 | c.parametrize() 56 | 57 | def test_2d_skip(self): 58 | x = np.random.random((300, 3)) 59 | reader = DataInMemory(x) 60 | 61 | k = 2 62 | c = api.cluster_uniform_time(k=k, skip=100) 63 | 64 | c.data_producer = reader 65 | c.parametrize() 66 | 67 | def test_big_k(self): 68 | x = np.random.random((300, 3)) 69 | reader = DataInMemory(x) 70 | k=151 71 | c = api.cluster_uniform_time(k=k) 72 | 73 | c.data_producer = reader 74 | c.parametrize() 75 | 76 | 77 | if __name__ == "__main__": 78 | unittest.main() -------------------------------------------------------------------------------- /chainsaw/util/contexts.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Created on 04.01.2016 3 | 4 | @author: marscher 5 | ''' 6 | from contextlib import contextmanager 7 | import random 8 | 9 | import numpy as np 10 | 11 | 12 | class conditional(object): 13 | """Wrap another context manager and enter it only if condition is true. 
@contextmanager
def settings(**kwargs):
    """ apply given config values temporarily within the given context.

    Parameters
    ----------
    **kwargs
        config attribute names mapped to the values they should have inside
        the ``with`` block.

    Raises
    ------
    ValueError
        if a keyword is not contained in ``config.keys()``; nothing has been
        changed at that point.
    """
    from chainsaw import config
    # validate everything up front, so an invalid key leaves config untouched
    valid_keys = config.keys()
    for k in kwargs:
        if k not in valid_keys:
            raise ValueError("not a valid settings: {key}".format(key=k))

    old_settings = {}
    try:
        for k, v in kwargs.items():
            old_settings[k] = getattr(config, k)
            setattr(config, k, v)

        yield
    finally:
        # restore old settings, also when the body of the with block (or one
        # of the assignments above) raised
        for k, v in old_settings.items():
            setattr(config, k, v)
class TestMiniBatchKmeans(TestCase):
    """Checks that mini batch k-means places centers in all three data modes."""

    def _assert_centers_cover_all_modes(self, x):
        """Assert that ``x`` has a value left of -1, within (-1, 1) and right of 1.

        The data is drawn from unit-variance gaussians shifted by -2, 0 and
        +2, so each of the three regions is expected to receive a center.
        """
        assert (np.any(x < -1.0))
        assert (np.any((x > -1.0) * (x < 1.0)))
        assert (np.any(x > 1.0))

    def test_3gaussian_1d_singletraj(self):
        # generate 1D data from three gaussians
        X = [np.random.randn(200) - 2.0,
             np.random.randn(300),
             np.random.randn(400) + 2.0]
        X = np.hstack(X)
        kmeans = cluster_mini_batch_kmeans(X, batch_size=0.5, k=100, max_iter=10000)
        cc = kmeans.clustercenters
        self._assert_centers_cover_all_modes(cc)

    def test_3gaussian_2d_multitraj(self):
        # generate 2D data: first coordinate from three gaussians, second
        # coordinate constantly zero
        X1 = np.zeros((200, 2))
        X1[:, 0] = np.random.randn(200) - 2.0
        X2 = np.zeros((300, 2))
        X2[:, 0] = np.random.randn(300)
        X3 = np.zeros((400, 2))
        X3[:, 0] = np.random.randn(400) + 2.0
        X = [X1, X2, X3]
        kmeans = cluster_mini_batch_kmeans(X, batch_size=0.5, k=100, max_iter=10000)
        cc = kmeans.clustercenters
        # only the first coordinate carries information; the zero column
        # would satisfy the "middle" check trivially
        self._assert_centers_cover_all_modes(cc[:, 0])
def inform_children_upon_change(f):
    """ decorator which calls ``self._stream_on_change()`` after the decorated
    method has run.

    The wrapper returns whatever the decorated method returned. If the method
    raises, ``_stream_on_change`` is not called.
    """
    from functools import wraps

    # keep name and docstring of the decorated method intact
    @wraps(f)
    def _notify(self, *args, **kw):
        # first call the decorated function, then inform about the change.
        res = f(self, *args, **kw)
        self._stream_on_change()
        return res

    return _notify
class TemporaryDirectory(object):
    """Context manager around ``tempfile.mkdtemp``.

    Entering the context creates the directory and yields its path; leaving
    the context removes the directory together with its content.

    Examples
    --------
    >>> import os
    >>> with TemporaryDirectory() as tmp:
    ...    path = os.path.join(tmp, "myfile.dat")
    ...    fh = open(path, 'w')
    ...    _ = fh.write('hello world')
    ...    fh.close()
    """

    def __init__(self, prefix='', suffix='', dir=None):
        # nothing is created yet, the arguments are only remembered
        self.prefix, self.suffix, self.dir = prefix, suffix, dir
        self.tmpdir = None

    def __enter__(self):
        created = tempfile.mkdtemp(self.suffix, self.prefix, self.dir)
        self.tmpdir = created
        return created

    def __exit__(self, *args):
        # failures during removal are ignored on purpose
        shutil.rmtree(self.tmpdir, ignore_errors=True)
def find_latest(versions):
    """Return the first entry whose ``'latest'`` value equals True.

    ``None`` is returned when no entry is flagged.
    """
    flagged = (entry for entry in versions if entry['latest'] == True)
    return next(flagged, None)
def combinations(seq, k):
    """ Return k length subsequences of elements from the input iterable.

    The combinations produced by itertools are streamed directly into one
    integer NumPy array, so no intermediate list of tuples is built.

    Parameters
    ----------
    seq : sequence
        elements to combine; ``len(seq)`` is used and the elements are
        stored with dtype int.
    k : int
        length of each combination.

    Returns
    -------
    res : ndarray of int, shape (n_combinations, k)
        one combination per row, in the order itertools yields them.

    Examples
    --------

    >>> import numpy as np
    >>> from itertools import combinations as iter_comb
    >>> x = np.arange(3)
    >>> c1 = combinations(x, 2)
    >>> print(c1) # doctest: +NORMALIZE_WHITESPACE
    [[0 1]
     [0 2]
     [1 2]]
    >>> c2 = np.array(tuple(iter_comb(x, 2)))
    >>> print(c2) # doctest: +NORMALIZE_WHITESPACE
    [[0 1]
     [0 2]
     [1 2]]
    """
    from itertools import combinations as _combinations, chain
    try:
        from scipy.special import comb
    except ImportError:
        # fall back to the former location for SciPy installations
        # that do not provide scipy.special.comb
        from scipy.misc import comb

    count = comb(len(seq), k, exact=True)
    res = np.fromiter(chain.from_iterable(_combinations(seq, k)),
                      int, count=count*k)
    # the number of rows is known, so state it explicitly instead of -1
    return res.reshape(count, k)
class MockLoggingHandler(logging.Handler):
    """Logging handler which collects emitted messages per level name.

    The collected messages are available in the ``messages`` dictionary,
    keyed by the lower case level name.
    """

    def __init__(self, *args, **kwargs):
        # create the (empty) message store before the base class is set up
        self.reset()
        logging.Handler.__init__(self, *args, **kwargs)

    def emit(self, record):
        bucket = self.messages[record.levelname.lower()]
        bucket.append(record.getMessage())

    def reset(self):
        level_names = ('debug', 'info', 'warning', 'error', 'critical')
        self.messages = dict((name, []) for name in level_names)
def create_traj(top=None, format='.xtc', dir=None, length=1000, start=0):
    """Write a synthetic trajectory to a fresh temporary file.

    Parameters
    ----------
    top : str or None
        topology file handed to ``mdtraj.load``; defaults to ``get_top()``.
    format : str
        file suffix of the trajectory file (determines the written format).
    dir : str or None
        directory in which the temporary file is created.
    length : int
        number of frames.
    start : int
        frame offset; coordinates are the consecutive integers
        ``start*9 ... (start+length)*9 - 1``.

    Returns
    -------
    trajfile : str
        path of the written trajectory (the caller has to remove it).
    xyz : ndarray, shape (length, 3, 3)
        the coordinates that were assigned to the trajectory.
    length : int
        number of frames, as passed in.
    """
    xyz = np.arange(start * 3 * 3, (start + length) * 3 * 3)
    xyz = xyz.reshape((-1, 3, 3))
    if top is None:
        top = get_top()

    t = mdtraj.load(top)
    t.xyz = xyz
    # same three unit cell vectors repeated for every frame
    t.unitcell_vectors = np.array(length * [[0, 0, 1], [0, 1, 0], [1, 0, 0]]).reshape(length, 3, 3)
    t.time = np.arange(length)

    # mkstemp creates the file atomically; the deprecated mktemp only returns
    # a name that another process could claim before we write to it.
    # NOTE(review): relies on mdtraj's save() overwriting an existing file.
    fd, trajfile = tempfile.mkstemp(suffix=format, dir=dir)
    os.close(fd)
    t.save(trajfile)

    return trajfile, xyz, length
3 | # 4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER) 5 | # 6 | # PyEMMA is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU Lesser General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU Lesser General Public License 17 | # along with this program. If not, see . 18 | 19 | 20 | """ 21 | Test the save_trajs function of the coordinates API by comparing 22 | the direct, sequential retrieval of frames via mdtraj.load_frame() vs 23 | the retrival via save_trajs 24 | @author: gph82, clonker 25 | """ 26 | 27 | from __future__ import absolute_import 28 | 29 | import unittest 30 | 31 | import numpy as np 32 | import chainsaw as coor 33 | 34 | 35 | class TestClusterSamples(unittest.TestCase): 36 | 37 | def setUp(self): 38 | self.input_trajs = [[0,1,2], 39 | [3,4,5], 40 | [6,7,8], 41 | [0,1,2], 42 | [3,4,5], 43 | [6,7,8]] 44 | self.cluster_obj = coor.cluster_regspace(data=self.input_trajs, dmin=.5) 45 | 46 | def test_index_states(self): 47 | # Test that the catalogue is being set up properly 48 | 49 | # The assingment-catalogue is easy to see from the above dtrajs 50 | ref = [[[0,0],[3,0]], # appearances of the 1st cluster 51 | [[0,1],[3,1]], # appearances of the 2nd cluster 52 | [[0,2],[3,2]], # appearances of the 3rd cluster 53 | [[1,0],[4,0]], # ..... 
class LaggedCorrelation(object):
    """ Running estimator of the correlation matrices C0 and Ctau. """

    def __init__(self, output_dimension, tau=1):
        """ Computes correlation matrices C0 and Ctau from a bunch of trajectories

        Parameters
        ----------
        output_dimension: int
            Number of basis functions.
        tau: int
            Lag time

        """
        self.tau = tau
        self.output_dimension = output_dimension
        # Initialize the two correlation matrices:
        self.Ct = np.zeros((self.output_dimension, self.output_dimension))
        self.C0 = np.zeros((self.output_dimension, self.output_dimension))
        # Create counter for the frames used for C0, Ct:
        self.nC0 = 0
        self.nCt = 0

    def add(self, X):
        """ Adds trajectory to the running estimate for computing C0 and Ct:

        Parameters
        ----------
        X: ndarray (T,N)
            basis function trajectory of T time steps for N basis functions.

        Raises
        ------
        ValueError
            If N differs from ``output_dimension`` or if T <= tau.

        """
        # Raise an error if output dimension is wrong (ValueError is a
        # subclass of the generic Exception that was raised before):
        if X.shape[1] != self.output_dimension:
            raise ValueError("Number of basis functions is incorrect. "
                             "Got %d, expected %d." % (X.shape[1],
                                                       self.output_dimension))
        # Raise an error if number of time steps is too small:
        if X.shape[0] <= self.tau:
            raise ValueError("Number of time steps is too small.")

        # Number of (t, t+tau) pairs in this trajectory:
        n_pairs = X.shape[0] - self.tau
        # Get the time-lagged data:
        Y1 = X[self.tau:, :]
        # Remove the last tau frames from X. An explicit stop index is used
        # because X[:-tau] would be empty for tau=0.
        Y2 = X[:n_pairs, :]
        TX = 1.0 * n_pairs
        # Update time-lagged correlation matrix:
        self.Ct += np.dot(Y1.T, Y2)
        self.nCt += TX
        # Update the instantaneous correlation matrix (both halves count):
        self.C0 += np.dot(Y1.T, Y1) + np.dot(Y2.T, Y2)
        self.nC0 += 2 * TX

    def GetC0(self):
        """ Returns the current estimate of C0:

        Returns
        -------
        C0: ndarray (N,N)
            symmetrized instantaneous correlation matrix of N basis
            functions, normalized by (number of frames used - 1).

        """
        return 0.5 * (self.C0 + self.C0.T) / (self.nC0 - 1)

    def GetCt(self):
        """ Returns the current estimate of Ctau

        Returns
        -------
        Ct: ndarray (N,N)
            symmetrized time lagged correlation matrix of N basis
            functions, normalized by (number of frame pairs used - 1).

        """
        return 0.5 * (self.Ct + self.Ct.T) / (self.nCt - 1)
10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU Lesser General Public License 17 | # along with this program. If not, see . 18 | 19 | 20 | r""" 21 | .. currentmodule:: chainsaw 22 | 23 | User API 24 | ======== 25 | 26 | **Trajectory input/output and featurization** 27 | 28 | .. autosummary:: 29 | :toctree: generated/ 30 | 31 | featurizer 32 | load 33 | source 34 | pipeline 35 | discretizer 36 | save_traj 37 | save_trajs 38 | 39 | **Coordinate and feature transformations** 40 | 41 | .. autosummary:: 42 | :toctree: generated/ 43 | 44 | pca 45 | tica 46 | 47 | **Clustering Algorithms** 48 | 49 | .. autosummary:: 50 | :toctree: generated/ 51 | 52 | cluster_kmeans 53 | cluster_mini_batch_kmeans 54 | cluster_regspace 55 | cluster_uniform_time 56 | assign_to_centers 57 | 58 | Classes 59 | ======= 60 | **Coordinate classes** encapsulating complex functionality. You don't need to 61 | construct these classes yourself, as this is done by the user API functions above. 62 | Find here a documentation how to extract features from them. 63 | 64 | **I/O and Featurization** 65 | 66 | .. autosummary:: 67 | :toctree: generated/ 68 | 69 | data.MDFeaturizer 70 | data.CustomFeature 71 | 72 | **Transformation estimators** 73 | 74 | .. autosummary:: 75 | :toctree: generated/ 76 | 77 | transform.PCA 78 | transform.TICA 79 | 80 | **Clustering algorithms** 81 | 82 | .. autosummary:: 83 | :toctree: generated/ 84 | 85 | clustering.KmeansClustering 86 | clustering.MiniBatchKmeansClustering 87 | clustering.RegularSpaceClustering 88 | clustering.UniformTimeClustering 89 | 90 | **Transformers** 91 | 92 | .. 
def histogram(transform, dimensions, nbins):
    '''Computes the N-dimensional histogram of the transformed data.

    The data is traversed twice: once to find the value range along each
    selected dimension and once to accumulate the counts, so `transform`
    must be iterable more than once.

    Parameters
    ----------
    transform : chainsaw.transfrom.Transformer object
        transform that provides the input data; iterating it yields
        pairs whose second element is a 2-D chunk (frames x dimensions)
    dimensions : tuple of indices
        indices of the dimensions you want to examine
    nbins : tuple of ints
        number of bin edges along each dimension (used as `num` for
        np.linspace), i.e. there are nbins[i]-1 bins along dimension i

    Returns
    -------
    counts : (nbins[0]-1, nbins[1]-1, ...) ndarray
        counts compatible with pyplot.pcolormesh and pyplot.bar
    edges : list of (nbins[i]) ndarrays
        bin edges compatible with pyplot.pcolormesh and pyplot.bar,
        see below.

    Examples
    --------

    >>> import matplotlib.pyplot as plt # doctest: +SKIP

    Only for ipython notebook
    >> %matplotlib inline  # doctest: +SKIP

    >>> counts, edges=histogram(transform, dimensions=(0,1), nbins=(20, 30)) # doctest: +SKIP
    >>> plt.pcolormesh(edges[0], edges[1], counts.T) # doctest: +SKIP

    >>> counts, edges=histogram(transform, dimensions=(1,), nbins=(50,)) # doctest: +SKIP
    >>> plt.bar(edges[0][:-1], counts, width=edges[0][1:]-edges[0][:-1]) # doctest: +SKIP
    '''
    n_dims = len(dimensions)
    upper = np.full(n_dims, -np.inf)
    lower = np.full(n_dims, np.inf)
    # first pass: running per-dimension extrema over all chunks
    for _, chunk in transform:
        selected = chunk[:, dimensions]
        upper = np.maximum(upper, np.max(selected, axis=0))
        lower = np.minimum(lower, np.min(selected, axis=0))
    # equally spaced edges between the global minimum and maximum
    edges = [np.linspace(lo, hi, num=n_edges)
             for lo, hi, n_edges in zip(lower, upper, nbins)]
    counts = np.zeros(np.array(nbins) - 1)
    # second pass: accumulate the per-chunk histograms
    for _, chunk in transform:
        counts += np.histogramdd(chunk[:, dimensions], bins=edges)[0]
    return counts, edges
-------------------------------------------------------------------------------- 1 | /* * This file is part of PyEMMA. 2 | * 3 | * Copyright (c) 2015, 2014 Computational Molecular Biology Group 4 | * 5 | * PyEMMA is free software: you can redistribute it and/or modify 6 | * it under the terms of the GNU Lesser General Public License as published by 7 | * the Free Software Foundation, either version 3 of the License, or 8 | * (at your option) any later version. 9 | * 10 | * This program is distributed in the hope that it will be useful, 11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | * GNU General Public License for more details. 14 | * 15 | * You should have received a copy of the GNU Lesser General Public License 16 | * along with this program. If not, see . 17 | */ 18 | 19 | #ifndef _CLUSTERING_H_ 20 | #define _CLUSTERING_H_ 21 | #ifdef __cplusplus 22 | extern "C" { 23 | #endif 24 | 25 | #include 26 | #define PY_ARRAY_UNIQUE_SYMBOL pyemma_clustering_ARRAY_API 27 | #include 28 | 29 | #ifdef _OPENMP 30 | #include 31 | #endif /* _OPENMP */ 32 | 33 | #include 34 | #include 35 | #include 36 | #include 37 | #include 38 | #include 39 | #include 40 | 41 | 42 | #ifndef NULL 43 | #warning null defined manually... 
44 | #define NULL 0x0 45 | #endif 46 | 47 | #if defined(__GNUC__) && ((__GNUC__ > 3) || (__GNUC__ == 3 && __GNUC_MINOR__ >= 1)) 48 | # define SKP_restrict __restrict 49 | #elif defined(_MSC_VER) && _MSC_VER >= 1400 50 | # define SKP_restrict __restrict 51 | #else 52 | # define SKP_restrict 53 | #endif 54 | 55 | #define ASSIGN_SUCCESS 0 56 | #define ASSIGN_ERR_NO_MEMORY 1 57 | #define ASSIGN_ERR_INVALID_METRIC 2 58 | 59 | static char ASSIGN_USAGE[] = "assign(chunk, centers, dtraj, metric)\n"\ 60 | "Assigns frames in `chunk` to the closest cluster centers.\n"\ 61 | "\n"\ 62 | "Parameters\n"\ 63 | "----------\n"\ 64 | "chunk : (N,M) C-style contiguous and behaved ndarray of np.float32\n"\ 65 | " (input) array of N frames, each frame having dimension M\n"\ 66 | "centers : (M,K) ndarray-like of np.float32\n"\ 67 | " (input) Non-empty array-like of cluster centers.\n"\ 68 | "dtraj : (N) ndarray of np.int64\n"\ 69 | " (output) discretized trajectory\n"\ 70 | " dtraj[i]=argmin{ d(chunk[i,:],centers[j,:]) | j in 0...(K-1) }\n"\ 71 | " where d is the metric that is specified with the argument `metric`.\n"\ 72 | "metric : string\n"\ 73 | " (input) One of \"euclidean\" or \"minRMSD\" (case sensitive).\n"\ 74 | "\n"\ 75 | "Returns \n"\ 76 | "-------\n"\ 77 | "None\n"\ 78 | "\n"\ 79 | "Note\n"\ 80 | "----\n"\ 81 | "This function uses the minRMSD implementation of mdtraj."; 82 | 83 | // euclidean metric 84 | float euclidean_distance(float *SKP_restrict a, float *SKP_restrict b, size_t n, float *buffer_a, float *buffer_b, float*dummy); 85 | // minRMSD metric 86 | float minRMSD_distance(float *SKP_restrict a, float *SKP_restrict b, size_t n, float *SKP_restrict buffer_a, float *SKP_restrict buffer_b, 87 | float* pre_calc_trace_a); 88 | 89 | // assignment to cluster centers from python 90 | PyObject *assign(PyObject *self, PyObject *args); 91 | // assignment to cluster centers from c 92 | int c_assign(float *chunk, float *centers, npy_int32 *dtraj, char* metric, 93 | Py_ssize_t 
class Estimator(_BaseEstimator, _Loggable):
    """ Base class for Chainsaws estimators """
    # becomes True (per instance) once estimate() has completed
    _estimated = False

    def estimate(self, X, **params):
        """ Estimates the model given the data X

        Parameters
        ----------
        X : object
            A reference to the data from which the model will be estimated
        params : dict
            New estimation parameter values. They are applied through
            set_params() before estimating, so they replace the values given
            to __init__ and remain set after this call. Use this if only a
            few parameters change for this run and the original settings do
            not need to be remembered.

        Returns
        -------
        estimator : object
            This estimator (self), with the model being available.

        """
        # apply parameter overrides, if any were given
        if params:
            self.set_params(**params)
        estimated_model = self._estimate(X)
        self._model = estimated_model
        self._estimated = True
        return self

    def _estimate(self, X):
        # hook for subclasses: perform the estimation and return the model
        message = 'You need to overload the _estimate() method in your Estimator implementation!'
        raise NotImplementedError(message)

    def fit(self, X):
        """Estimates parameters - for compatibility with sklearn.

        Parameters
        ----------
        X : object
            A reference to the data from which the model will be estimated

        Returns
        -------
        estimator : object
            The estimator (self) with estimated model.

        """
        self.estimate(X)
        return self

    @property
    def model(self):
        """The model estimated by this Estimator"""
        try:
            return self._model
        except AttributeError:
            # _model is only assigned by estimate()
            raise AttributeError(
                'Model has not yet been estimated. Call estimate(X) or fit(X) first')
# From http://stackoverflow.com/questions/
#            7018879/disabling-output-when-compiling-with-distutils
def hasfunction(cc, funcname):
    """Check whether `funcname` can be compiled and linked with compiler `cc`.

    A tiny C program calling ``funcname()`` is written to a temporary
    directory, then compiled and linked. Compiler output on stderr is
    silenced while probing.

    Parameters
    ----------
    cc : distutils compiler object
        must provide ``compile(sources, output_dir=...)`` and
        ``link_executable(objects, output_name)``.
    funcname : str
        name of the C function to probe for.

    Returns
    -------
    bool
        True if compiling and linking succeeded, False if any step raised.
    """
    tmpdir = tempfile.mkdtemp(prefix='hasfunction-')
    devnull = oldstderr = stderr_fd = None
    try:
        try:
            fname = os.path.join(tmpdir, 'funcname.c')
            with open(fname, 'w') as f:
                f.write('int main(void) {\n')
                f.write('    %s();\n' % funcname)
                f.write('}\n')
            # Redirect stderr to the null device to hide any error messages
            # from the compiler. sys.stderr may have been replaced by an
            # object without a real descriptor; then we simply don't silence.
            try:
                stderr_fd = sys.stderr.fileno()
            except (AttributeError, OSError, ValueError):
                stderr_fd = None
            if stderr_fd is not None:
                devnull = open(os.devnull, 'w')
                oldstderr = os.dup(stderr_fd)
                os.dup2(devnull.fileno(), stderr_fd)
            objects = cc.compile([fname], output_dir=tmpdir)
            cc.link_executable(objects, os.path.join(tmpdir, "a.out"))
        except Exception:
            # any failure means "function not available"; KeyboardInterrupt
            # and SystemExit deliberately propagate.
            return False
        return True
    finally:
        if oldstderr is not None:
            os.dup2(oldstderr, stderr_fd)
            # release the duplicate, otherwise one descriptor leaks per call
            os.close(oldstderr)
        if devnull is not None:
            devnull.close()
        shutil.rmtree(tmpdir)
class Loggable(object):
    """Mixin providing a lazily created, per-instance logger.

    The logger is named after the instance (see `name`) and is removed from
    the logging manager again once the instance has been garbage collected.
    """
    # counting instances of Loggable, incremented by name property.
    __ids = count(0)
    # holds weak references to instances of this, to clean up logger instances.
    __refs = {}

    _loglevel_DEBUG = logging.DEBUG
    _loglevel_INFO = logging.INFO
    _loglevel_WARN = logging.WARN
    _loglevel_ERROR = logging.ERROR
    _loglevel_CRITICAL = logging.CRITICAL

    @property
    def name(self):
        """The name of this instance"""
        try:
            return self._name
        except AttributeError:
            # first access: build "<module>.<class>[<running number>]"
            self._name = "%s.%s[%i]" % (self.__module__,
                                        self.__class__.__name__,
                                        next(Loggable.__ids))
            return self._name

    @property
    def logger(self):
        """The logger for this class instance """
        try:
            return self._logger_instance
        except AttributeError:
            self.__create_logger()
            return self._logger_instance

    @property
    def _logger(self):
        # alias of `logger`
        return self.logger

    def _logger_is_active(self, level):
        """ Returns True if a message of severity `level` would be processed.

        @param level: int log level (debug=10, info=20, warn=30, error=40, critical=50)"""
        # The raw `logger.level` is NOTSET (0) unless set on this very logger,
        # so comparing it against `level` ignored inherited configuration;
        # isEnabledFor() uses the effective level.
        return self.logger.isEnabledFor(level)

    def __create_logger(self):
        _weak_logger_refs = Loggable.__refs
        # creates a logger based on the attribute "name" of self
        self._logger_instance = logging.getLogger(self.name)

        # store a weakref to this instance to clean the logger instance.
        logger_id = id(self._logger_instance)
        r = weakref.ref(self, Loggable._cleanup_logger(logger_id, self.name))
        _weak_logger_refs[logger_id] = r

    @staticmethod
    def _cleanup_logger(logger_id, logger_name):
        # callback function used in conjunction with weakref.ref
        # removes logger from root manager

        def remove_logger(weak):
            d = logging.getLogger().manager.loggerDict
            # pop() instead of del: the entry may already be gone and an
            # exception inside a weakref callback cannot be handled by anyone.
            d.pop(logger_name, None)
            Loggable.__refs.pop(logger_id, None)
        return remove_logger

    def __getstate__(self):
        # do not pickle the logger instance
        d = dict(self.__dict__)
        d.pop('_logger_instance', None)
        return d
10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU Lesser General Public License 17 | # along with this program. If not, see . 18 | 19 | ''' 20 | Created on 18.02.2015 21 | 22 | @author: marscher 23 | ''' 24 | 25 | from __future__ import absolute_import 26 | 27 | import numpy as np 28 | import six 29 | 30 | from .interface import AbstractClustering 31 | from chainsaw.util.annotators import fix_docs 32 | 33 | 34 | @fix_docs 35 | class AssignCenters(AbstractClustering): 36 | 37 | """Assigns given (pre-calculated) cluster centers. If you already have 38 | cluster centers from somewhere, you use this class to assign your data to it. 39 | 40 | Parameters 41 | ---------- 42 | clustercenters : path to file (csv) or npyfile or ndarray 43 | cluster centers to use in assignment of data 44 | metric : str 45 | metric to use during clustering ('euclidean', 'minRMSD') 46 | stride : int 47 | stride 48 | n_jobs : int or None, default None 49 | Number of threads to use during assignment of the data. 50 | If None, all available CPUs will be used. 51 | skip : int, default=0 52 | skip the first initial n frames per trajectory. 53 | Examples 54 | -------- 55 | Assuming you have stored your centers in a CSV file: 56 | 57 | >>> from chainsaw.clustering import AssignCenters 58 | >>> from chainsaw import pipeline 59 | >>> reader = ... 
# doctest: +SKIP 60 | >>> assign = AssignCenters('my_centers.dat') # doctest: +SKIP 61 | >>> disc = pipeline(reader, cluster=assign) # doctest: +SKIP 62 | >>> disc.parametrize() # doctest: +SKIP 63 | 64 | """ 65 | 66 | def __init__(self, clustercenters, metric='euclidean', stride=1, n_jobs=None, skip=0): 67 | super(AssignCenters, self).__init__(metric=metric, n_jobs=n_jobs) 68 | 69 | if isinstance(clustercenters, six.string_types): 70 | from chainsaw.data import create_file_reader 71 | reader = create_file_reader(clustercenters, None, None) 72 | clustercenters = reader.get_output()[0] 73 | else: 74 | clustercenters = np.array(clustercenters, dtype=np.float32, order='C') 75 | 76 | # sanity check. 77 | if not clustercenters.ndim == 2: 78 | raise ValueError('cluster centers have to be 2d') 79 | 80 | self.set_params(clustercenters=clustercenters, metric=metric, stride=stride, skip=skip) 81 | 82 | # since we provided centers, no estimation is required. 83 | self._estimated = True 84 | 85 | def describe(self): 86 | return "[{name} centers shape={shape}]".format(name=self.name, shape=self.clustercenters.shape) 87 | 88 | @AbstractClustering.data_producer.setter 89 | def data_producer(self, dp): 90 | # check dimensions 91 | dim = self.clustercenters.shape[1] 92 | if not dim == dp.dimension(): 93 | raise ValueError('cluster centers have wrong dimension. Have dim=%i' 94 | ', but input has %i' % (dim, dp.dimension())) 95 | AbstractClustering.data_producer.fset(self, dp) 96 | 97 | def _estimate(self, iterable, **kw): 98 | old_source = self._data_producer 99 | self.data_producer = iterable 100 | try: 101 | self.assign(None, self.stride) 102 | finally: 103 | self.data_producer = old_source 104 | 105 | return self 106 | -------------------------------------------------------------------------------- /chainsaw/tests/test_api_load.py: -------------------------------------------------------------------------------- 1 | 2 | # This file is part of PyEMMA. 
class TestAPILoad(unittest.TestCase):
    """Tests of the load() API function for single and multiple trajectories."""

    @classmethod
    def setUpClass(cls):
        # locate the bundled test data once for all tests of this class
        data_dir = pkg_resources.resource_filename(__name__, 'data') + os.path.sep
        cls.bpti_pdbfile = os.path.join(data_dir, 'bpti_ca.pdb')
        cls.bpti_mini_files = [
            os.path.join(data_dir, 'bpti_mini%s' % suffix)
            for suffix in ('.xtc', '.binpos', '.dcd', '.h5', '.lh5', '.nc', '.netcdf', '.trr')
        ]

    def testUnicodeString_without_featurizer(self):
        # a trajectory file name alone (no featurizer/topology) is rejected
        unicode_name = text_type(traj_files[0])

        with self.assertRaises(ValueError):
            load(unicode_name)

    def testUnicodeString(self):
        unicode_name = text_type(traj_files[0])
        feat = api.featurizer(pdb_file)

        load(unicode_name, feat)

    def test_various_formats_load(self):
        # every file is compared with the one loaded immediately before it
        # (the comparison carries over from one chunk size to the next)
        previous_data = None
        previous_file = None
        for cs in [0, 13]:
            for current_file in self.bpti_mini_files:
                current_data = api.load(current_file, top=self.bpti_pdbfile, chunk_size=cs)
                if previous_data is not None:
                    np.testing.assert_array_almost_equal(
                        previous_data, current_data,
                        err_msg='Comparing %s to %s failed for chunksize %s'
                                % (current_file, previous_file, cs))
                previous_data = current_data
                previous_file = current_file

    def test_load_traj(self):
        # a single file name yields a single array
        feat = api.featurizer(pdb_file)
        res = load(traj_files[0], feat)

        self.assertEqual(type(res), np.ndarray)

    def test_load_trajs(self):
        # a list of file names yields a list of arrays
        feat = api.featurizer(pdb_file)
        res = load(traj_files, feat)

        self.assertEqual(type(res), list)
        self.assertTrue(all(type(x) == np.ndarray for x in res))

    def test_with_trajs_without_featurizer_or_top(self):

        with self.assertRaises(ValueError):
            load(traj_files)

        # with a topology the same call succeeds, one output per input file
        output = load(traj_files, top=pdb_file)
        self.assertEqual(type(output), list)
        self.assertEqual(len(output), len(traj_files))

    def test_non_existant_input(self):
        input_files = [traj_files[0], "does_not_exist_for_sure"]

        with self.assertRaises(ValueError):
            load(trajfiles=input_files, top=pdb_file)

    def test_empty_list(self):
        with self.assertRaises(ValueError):
            load([], top=pdb_file)
class LoggingConfigurationError(RuntimeError):
    """Raised when no usable logging configuration could be obtained."""


def setup_logging(config, D=None):
    """Set up the logging system from a dict or from the configured YAML file.

    Parameters
    ----------
    config : object
        Configuration wrapper. The attributes read here are ``logging_config``
        (path of a YAML file, or the string 'DEFAULT' in any case),
        ``default_logging_file`` and, only in the log file fallback, ``cfg_dir``.
        It is not touched when a non-empty `D` is given and ``dictConfig``
        succeeds.
    D : dict, optional
        Dictionary in ``logging.config.dictConfig`` format. If missing or
        empty, it is loaded from the YAML file selected by `config`.
        The dictionary is modified in place ('disable_existing_loggers'
        default, possibly the file name of the 'rotating_files' handler).

    Raises
    ------
    LoggingConfigurationError
        If neither the configured nor the default file can be read, or the
        loaded configuration is empty.
    ValueError
        Re-raised from ``dictConfig`` if the problem is not the file of the
        'rotating_files' handler.
    """
    if not D:
        import yaml

        args = config.logging_config
        default = args.upper() == 'DEFAULT'
        src = config.default_logging_file if default else args

        # first try to read configured file
        try:
            with open(src) as f:
                # safe_load: a logging config consists of plain mappings/scalars.
                # Calling yaml.load without a Loader can construct arbitrary
                # objects and is rejected by current PyYAML releases.
                D = yaml.safe_load(f)
        except EnvironmentError as ee:
            if default:
                raise LoggingConfigurationError('could not handle default logging '
                                                'configuration file\n%s' % ee)
            # fall back to default
            try:
                with open(config.default_logging_file) as f2:
                    D = yaml.safe_load(f2)
            except EnvironmentError as ee2:
                raise LoggingConfigurationError('Could not read either configured nor '
                                                'default logging configuration!\n%s' % ee2)

        # covers an empty file (None) as well as an empty mapping; an explicit
        # raise is used because assert statements vanish under "python -O"
        if not D:
            raise LoggingConfigurationError('Empty logging config! Try using default config by'
                                            ' setting logging_conf=DEFAULT in chainsaw.cfg')

    # if the user has not explicitly disabled other loggers, we (contrary to Pythons
    # default value) do not want to override them.
    D.setdefault('disable_existing_loggers', False)

    # configure using the dict
    try:
        dictConfig(D)
    except ValueError as ve:
        # issue with file handler? .get() keeps the original ValueError visible
        # when the config has no 'handlers' section at all
        if 'files' in str(ve) and 'rotating_files' in D.get('handlers', {}):
            print("cfg dir", config.cfg_dir)
            new_file = os.path.join(config.cfg_dir, 'chainsaw.log')
            warnings.warn("set logfile to %s, because there was"
                          " an error writing to the desired one" % new_file)
            D['handlers']['rotating_files']['filename'] = new_file
        else:
            raise
        # second attempt with the relocated log file
        dictConfig(D)

    # get log file name of chainsaw root logger (None for handlers without a file)
    logger = logging.getLogger('chainsaw')
    log_files = [getattr(h, 'baseFilename', None) for h in logger.handlers]

    import atexit

    @atexit.register
    def clean_empty_log_files():
        """Shut down logging at interpreter exit and delete log files left empty."""
        # gracefully shutdown logging system
        logging.shutdown()
        for f in log_files:
            if f is not None and os.path.exists(f):
                try:
                    if os.stat(f).st_size == 0:
                        os.remove(f)
                except OSError as o:
                    print("during removal of empty logfiles there was a problem: ", o)
@fix_docs
class UniformTimeClustering(AbstractClustering):
    r"""Uniform time clustering

    Cluster centers are frames picked at (approximately) equal distances in
    time from the concatenation of all trajectories.
    """

    def __init__(self, n_clusters=2, metric='euclidean', stride=1, n_jobs=None, skip=0):
        r"""
        Uniform time clustering

        Parameters
        ----------
        n_clusters : int
            amount of desired cluster centers. When not specified (None),
            min(sqrt(N), 5000) is chosen as default value,
            where N denotes the number of data points
        metric : str
            metric to use during clustering ('euclidean', 'minRMSD')
        stride : int
            stride
        n_jobs : int or None, default None
            Number of threads to use during assignment of the data.
            If None, all available CPUs will be used.
        skip : int, default=0
            skip the first initial n frames per trajectory.
        """
        super(UniformTimeClustering, self).__init__(metric=metric, n_jobs=n_jobs)
        self.set_params(n_clusters=n_clusters, metric=metric, stride=stride, skip=skip)

    def describe(self):
        """Return a one-line description with the number of centers and the input dimension."""
        return "[Uniform time clustering, k = %i, inp_dim=%i]" \
               % (self.n_clusters, self.data_producer.dimension())

    def _estimate(self, iterable, **kw):
        """Select ``n_clusters`` frames of `iterable` as cluster centers.

        Parameters
        ----------
        iterable : data source
            has to provide ``n_frames_total(stride, skip)`` and
            ``iterator(stride, return_trajindex, chunk, skip)``.

        Returns
        -------
        self, with ``clustercenters`` (and possibly ``n_clusters``) updated.
        """
        if self.n_clusters is None:
            traj_lengths = self.trajectory_lengths(stride=self.stride, skip=self.skip)
            total_length = sum(traj_lengths)
            self.n_clusters = min(int(math.sqrt(total_length)), 5000)
            self._logger.info("The number of cluster centers was not specified, "
                              "using min(sqrt(N), 5000)=%s as n_clusters." % self.n_clusters)

        # initialize time counters
        T = iterable.n_frames_total(stride=self.stride, skip=self.skip)
        if self.n_clusters > T:
            # remember the requested value: it has to appear in the message
            # although n_clusters is capped right away
            requested = self.n_clusters
            self.n_clusters = T
            self._logger.info('Requested more clusters (k = %i)'
                              ' than there are total data points (%i)'
                              '. Will do clustering with k = %i'
                              % (requested, T, T))

        # first data point in the middle of the time segment
        next_t = (T // self.n_clusters) // 2
        # cumsum of lengths (unstrided; the frame numbers below are scaled by the stride)
        cumsum = np.cumsum(self.trajectory_lengths(skip=self.skip))
        # distribution of integers, truncate if n_clusters is too large
        step = (T - 2*next_t + 1) // self.n_clusters
        linspace = self.stride * np.arange(next_t, T - next_t + 1, step)[:self.n_clusters]
        # random access matrix: one (trajectory index, frame index) row per center
        ra_stride = np.array([UniformTimeClustering._idx_to_traj_idx(x, cumsum) for x in linspace])
        with iterable.iterator(stride=ra_stride, return_trajindex=False, chunk=self.chunksize, skip=self.skip) as it:
            self.clustercenters = np.concatenate(list(it))

        assert len(self.clustercenters) == self.n_clusters
        return self

    @staticmethod
    def _idx_to_traj_idx(idx, cumsum):
        """Translate a frame number of the concatenated data into a trajectory position.

        Parameters
        ----------
        idx : int
            frame number counted over all trajectories.
        cumsum : sequence of int
            cumulative trajectory lengths.

        Returns
        -------
        (trajectory index, frame index inside that trajectory)

        Raises
        ------
        ValueError
            if `idx` is not contained in [0, cumsum[-1]).
        """
        prev_len = 0
        for trajIdx, length in enumerate(cumsum):
            if prev_len <= idx < length:
                return trajIdx, idx - prev_len
            prev_len = length
        raise ValueError("Requested index %s was out of bounds [0,%s)" % (idx, cumsum[-1]))
12 1.165 -13.584 4.948 0.00 0.00 C1 C 13 | ATOM 188 CA PRO A 13 0.230 -17.400 5.025 0.00 0.00 C1 C 14 | ATOM 202 CA CYS A 14 3.079 -18.621 7.107 0.00 0.00 C1 C 15 | ATOM 212 CA LYS A 15 2.582 -19.937 10.866 0.00 0.00 C1 C 16 | ATOM 234 CA ALA A 16 4.792 -17.419 12.911 0.00 0.00 C1 C 17 | ATOM 244 CA ARG A 17 3.794 -14.113 14.750 0.00 0.00 C1 C 18 | ATOM 268 CA ILE A 18 7.191 -12.307 14.496 0.00 0.00 C1 C 19 | ATOM 287 CA ILE A 19 7.143 -8.504 14.984 0.00 0.00 C1 C 20 | ATOM 306 CA ARG A 20 8.006 -6.620 11.818 0.00 0.00 C1 C 21 | ATOM 330 CA TYR A 21 7.877 -2.855 11.290 0.00 0.00 C1 C 22 | ATOM 351 CA PHE A 22 5.979 -0.736 8.661 0.00 0.00 C1 C 23 | ATOM 371 CA TYR A 23 4.972 2.928 7.941 0.00 0.00 C1 C 24 | ATOM 392 CA ASN A 24 1.390 3.611 9.146 0.00 0.00 C1 C 25 | ATOM 406 CA ALA A 25 0.700 6.390 6.695 0.00 0.00 C1 C 26 | ATOM 416 CA LYS A 26 -2.664 7.139 8.539 0.00 0.00 C1 C 27 | ATOM 438 CA ALA A 27 -0.686 8.124 11.724 0.00 0.00 C1 C 28 | ATOM 448 CA GLY A 28 2.403 9.458 9.885 0.00 0.00 C1 C 29 | ATOM 455 CA LEU A 29 4.794 7.134 11.896 0.00 0.00 C1 C 30 | ATOM 474 CA CYS A 30 6.628 3.768 11.766 0.00 0.00 C1 C 31 | ATOM 484 CA GLN A 31 4.809 1.106 13.924 0.00 0.00 C1 C 32 | ATOM 501 CA THR A 32 4.558 -2.660 14.660 0.00 0.00 C1 C 33 | ATOM 515 CA PHE A 33 2.741 -5.558 13.036 0.00 0.00 C1 C 34 | ATOM 535 CA VAL A 34 2.912 -9.385 13.166 0.00 0.00 C1 C 35 | ATOM 551 CA TYR A 35 4.160 -11.476 10.334 0.00 0.00 C1 C 36 | ATOM 572 CA GLY A 36 4.181 -15.187 9.310 0.00 0.00 C1 C 37 | ATOM 579 CA GLY A 37 7.863 -15.378 8.107 0.00 0.00 C1 C 38 | ATOM 586 CA CYS A 38 7.170 -15.482 4.311 0.00 0.00 C1 C 39 | ATOM 596 CA ARG A 39 6.097 -13.526 1.228 0.00 0.00 C1 C 40 | ATOM 620 CA ALA A 40 6.717 -9.865 2.308 0.00 0.00 C1 C 41 | ATOM 630 CA LYS A 41 4.813 -6.980 0.893 0.00 0.00 C1 C 42 | ATOM 652 CA ARG A 42 6.897 -3.891 0.135 0.00 0.00 C1 C 43 | ATOM 676 CA ASN A 43 5.975 -2.121 3.481 0.00 0.00 C1 C 44 | ATOM 690 CA ASN A 44 7.760 -4.726 5.781 0.00 0.00 C1 C 45 | ATOM 704 CA PHE A 
45 10.990 -3.890 7.611 0.00 0.00 C1 C 46 | ATOM 724 CA LYS A 46 13.280 -5.405 10.215 0.00 0.00 C1 C 47 | ATOM 746 CA SER A 47 13.349 -2.293 12.285 0.00 0.00 C1 C 48 | ATOM 757 CA ALA A 48 11.836 1.205 12.902 0.00 0.00 C1 C 49 | ATOM 767 CA GLU A 49 15.549 2.323 12.263 0.00 0.00 C1 C 50 | ATOM 782 CA ASP A 50 15.332 0.878 8.725 0.00 0.00 C1 C 51 | ATOM 794 CA CYS A 51 11.618 2.060 8.308 0.00 0.00 C1 C 52 | ATOM 804 CA MET A 52 12.708 5.603 9.261 0.00 0.00 C1 C 53 | ATOM 821 CA ARG A 53 15.884 5.306 7.040 0.00 0.00 C1 C 54 | ATOM 845 CA THR A 54 13.448 4.173 4.196 0.00 0.00 C1 C 55 | ATOM 859 CA CYS A 55 10.100 6.043 4.700 0.00 0.00 C1 C 56 | ATOM 869 CA GLY A 56 11.076 9.290 6.702 0.00 0.00 C1 C 57 | ATOM 876 CA GLY A 57 10.474 11.834 3.826 0.00 0.00 C1 C 58 | ATOM 883 CA ALA A 58 7.286 13.898 2.917 0.00 0.00 C1 C 59 | -------------------------------------------------------------------------------- /chainsaw/base/model.py: -------------------------------------------------------------------------------- 1 | 2 | # This file is part of PyEMMA. 3 | # 4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER) 5 | # 6 | # PyEMMA is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU Lesser General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU Lesser General Public License 17 | # along with this program. If not, see . 
class Model(object):
    """ Base class for Chainsaws models

    This class is inspired by sklearn's BaseEstimator class. However, we define parameter names not by the
    current class' __init__ but have to announce them. This allows us to also remember the parameters of model
    superclasses. This class can be mixed with Chainsaw and sklearn Estimators.

    """

    def _get_model_param_names(self):
        r"""Get parameter names for the model

        Returns
        -------
        list of str
            sorted argument names of ``set_model_params``; an empty list if
            the model does not define that method.

        Raises
        ------
        RuntimeError
            if ``set_model_params`` takes ``*args``.
        """
        if not hasattr(self, 'set_model_params'):
            # No parameters known
            return []

        # introspect the setter's arguments to find the model parameters
        # to represent
        args, varargs, kw, default = getargspec_no_self(self.set_model_params)
        if varargs is not None:
            raise RuntimeError("Models should always specify their parameters in the signature"
                               " of their set_model_params (no varargs). %s doesn't follow this convention."
                               % (self, ))
        # Remove 'self'
        # XXX: This is going to fail if the init is a staticmethod, but
        # who would do this?
        # NOTE(review): the helper's name suggests 'self' is already stripped,
        # in which case this drops the first real parameter - confirm against
        # chainsaw.util.reflection.getargspec_no_self.
        args.pop(0)
        args.sort()
        return args

    def update_model_params(self, **params):
        r"""Update given model parameter if they are set to specific values

        A parameter is written if it does not exist yet, if its current value
        is None, or if the new value is not None. Passing None therefore never
        overwrites a value that has already been set.
        """
        for key, value in params.items():
            # the getattr default folds "attribute missing" into "attribute is None"
            if value is not None or getattr(self, key, None) is None:
                setattr(self, key, value)

    def get_model_params(self, deep=True):
        r"""Get parameters for this model.

        Parameters
        ----------
        deep: boolean, optional
            If True, will return the parameters for this estimator and
            contained subobjects that are estimators.
        Returns
        -------
        params : mapping of string to any
            Parameter names mapped to their values. Parameters whose access
            emits a DeprecationWarning are left out.
        """
        out = dict()
        for key in self._get_model_param_names():
            with warnings.catch_warnings(record=True) as w:
                # We need deprecation warnings to always be on in order to
                # catch deprecated param values. The filter is installed
                # inside the context manager so that the global filter list
                # is restored unchanged on exit, even if the caller already
                # has an identical filter.
                warnings.simplefilter("always", DeprecationWarning)
                value = getattr(self, key, None)
            if len(w) and w[0].category == DeprecationWarning:
                # if the parameter is deprecated, don't show it
                continue

            # XXX: should we rather test if instance of estimator?
            if deep and hasattr(value, 'get_params'):
                deep_items = list(value.get_params().items())
                out.update((key + '__' + k, val) for k, val in deep_items)
            out[key] = value
        return out
class TestCache(unittest.TestCase):
    """Tests of ``Cache`` wrapped around file readers and TICA transformers."""

    @classmethod
    def setUpClass(cls):
        config.use_trajectory_lengths_cache = False
        cls.test_dir = tempfile.mkdtemp(prefix="test_cache_")

        cls.length = 1000
        cls.dim = 3
        # ten random trajectories, stored as 0.npy ... 9.npy
        for index in range(10):
            trajectory = np.random.random((cls.length, cls.dim))
            np.save(os.path.join(cls.test_dir, "{}.npy".format(index)), trajectory)

        cls.files = glob(cls.test_dir + "/*.npy")

    def setUp(self):
        super(TestCache, self).setUp()
        # every test gets its own cache directory below the class directory
        self.tmp_cache_dir = tempfile.mkdtemp(dir=self.test_dir)
        config.cache_dir = self.tmp_cache_dir
        self.logger = getLogger("chainsaw.test.%s"%self.id())

    @classmethod
    def tearDownClass(cls):
        import shutil
        shutil.rmtree(cls.test_dir, ignore_errors=False)

    def test_cache_hits(self):
        reader = chainsaw.source(self.files, chunk_size=1000)
        reader.describe()
        reference = reader.get_output()
        cache = Cache(reader)
        n_files = len(self.files)

        # nothing has been requested yet
        self.assertEqual(cache.data.misses, 0)
        first_pass = cache.get_output()
        self.assertEqual(cache.data.misses, n_files)

        for actual, desired in zip(first_pass, reference):
            np.testing.assert_allclose(actual, desired, atol=1e-15, rtol=1e-7)

        # the second pass is expected to be answered from the cache
        cache.get_output()
        self.assertEqual(len(cache.data.hits), n_files)

    def test_get_output(self):
        reader = chainsaw.source(self.files, chunk_size=0)
        options = dict(stride=2, dimensions=1, skip=3)
        desired = reader.get_output(**options)

        actual = Cache(reader).get_output(**options)
        np.testing.assert_allclose(actual, desired)

    def test_tica_cached_input(self):
        reader = chainsaw.source(self.files, chunk_size=0)
        cache = Cache(reader)

        cached = chainsaw.tica(cache)
        plain = chainsaw.tica(reader)

        np.testing.assert_allclose(cached.cov, plain.cov, atol=1e-10)
        np.testing.assert_allclose(cached.cov_tau, plain.cov_tau, atol=1e-9)

        np.testing.assert_allclose(cached.eigenvalues, plain.eigenvalues, atol=1e-7)
        # eigenvectors are compared by absolute value
        np.testing.assert_allclose(np.abs(cached.eigenvectors),
                                   np.abs(plain.eigenvectors), atol=1e-6)

    def test_tica_cached_output(self):
        reader = chainsaw.source(self.files, chunk_size=0)
        transformer = chainsaw.tica(reader, dim=2)

        expected = transformer.get_output()
        cache = Cache(transformer)

        np.testing.assert_allclose(cache.get_output(), expected)

    def test_cache_switch_cache_file(self):
        # only checks that a cache can be created on top of a transformer
        reader = chainsaw.source(self.files, chunk_size=0)
        transformer = chainsaw.tica(reader, dim=2)
        Cache(transformer)

    def test_with_feature_reader_switch_cache_file(self):
        import pkg_resources
        data_dir = pkg_resources.resource_filename(__name__, 'data') + os.path.sep
        topology = os.path.join(data_dir, 'bpti_ca.pdb')
        trajectory = os.path.join(data_dir, 'bpti_mini.xtc')

        reader = chainsaw.source(trajectory, top=topology)
        reader.featurizer.add_selection([0,1, 2])

        cache = Cache(reader)
        initial_name = cache.current_cache_file_name

        # a second feature has to lead to a different cache file
        reader.featurizer.add_selection([5, 8, 9])
        self.assertNotEqual(initial_name, cache.current_cache_file_name)

        # remove 2nd feature and check we've got the old name back.
        reader.featurizer.active_features.pop()
        self.assertEqual(cache.current_cache_file_name, initial_name)

        # add a new file and ensure we still got the same cache file
        reader.filenames.append(os.path.join(data_dir, 'bpti_001-033.xtc'))
        self.assertEqual(cache.current_cache_file_name, initial_name)


if __name__ == '__main__':
    unittest.main()
log = getLogger('chainsaw.'+'TestFeatureReaderAndTICA')


class TestFeatureReaderAndTICA(unittest.TestCase):
    """Feed cosine time series through FeatureReader and TICA and compare
    the estimated covariances with their analytical values."""

    @classmethod
    def setUpClass(cls):
        cls.dim = 9  # dimension (must be divisible by 3)
        N = 50000  # length of single trajectory # 500000
        N_trajs = 10  # number of trajectories

        cls.w = 2.0*np.pi*1000.0/N  # have 1000 cycles in each trajectory

        # get random amplitudes and phases
        cls.A = np.random.randn(cls.dim)
        cls.phi = np.random.random_sample((cls.dim,))*np.pi*2.0
        mean = np.random.randn(cls.dim)

        # create topology file; mkstemp creates the file atomically, whereas
        # the deprecated mktemp only returns a name that may be taken by
        # someone else before it is opened
        fd, cls.temppdb = tempfile.mkstemp(suffix='.pdb')
        with os.fdopen(fd, 'w') as f:
            for i in range(cls.dim//3):
                # NOTE(review): PDB ATOM records are column-positional - confirm
                # the spacing of this template against the repository version.
                print(('ATOM %5d C ACE A 1 28.490 31.600 33.379 0.00 1.00' % i), file=f)

        t = np.arange(0, N)
        t_total = 0
        cls.trajnames = []  # list of xtc file names
        for i in range(N_trajs):
            # set up data; the time offset continues the phase over the trajectories
            data = cls.A*np.cos((cls.w*(t+t_total))[:, np.newaxis]+cls.phi) + mean
            xyz = data.reshape((N, cls.dim//3, 3))
            # create trajectory file
            traj = mdtraj.load(cls.temppdb)
            traj.xyz = xyz
            traj.time = t
            fd, tempfname = tempfile.mkstemp(suffix='.xtc')
            os.close(fd)  # only the reserved name is needed, the writer reopens the file
            traj.save(tempfname)
            cls.trajnames.append(tempfname)
            t_total += N

    @classmethod
    def tearDownClass(cls):
        for fname in cls.trajnames:
            os.unlink(fname)
        os.unlink(cls.temppdb)
        super(TestFeatureReaderAndTICA, cls).tearDownClass()

    def test_covariances_and_eigenvalues(self):
        reader = FeatureReader(self.trajnames, self.temppdb, chunksize=10000)

        # analytical solution for C_ij(lag) is 0.5*A[i]*A[j]*cos(phi[i]-phi[j])*cos(w*lag);
        # the lag independent factor is computed once
        ana_cov = 0.5*self.A[:, np.newaxis]*self.A*np.cos(self.phi[:, np.newaxis]-self.phi)

        for lag in [1, 11, 101, 1001, 2001]:  # avoid cos(w*tau)==0
            trans = api.tica(data=reader, dim=self.dim, lag=lag)
            log.info('number of trajectories reported by tica %d' % trans.number_of_trajectories())
            log.info('tau = %d corresponds to a number of %f cycles' % (lag, self.w*lag/(2.0*np.pi)))
            trans.parametrize()

            ana_cov_tau = ana_cov*np.cos(self.w*lag)

            self.assertTrue(np.allclose(ana_cov, trans.cov, atol=1.E-3))
            self.assertTrue(np.allclose(ana_cov_tau, trans.cov_tau, atol=1.E-3))
            log.info('max. eigenvalue: %f' % np.max(trans.eigenvalues))
            self.assertTrue(np.all(trans.eigenvalues <= 1.0))

    def test_partial_fit(self):
        reader = FeatureReader(self.trajnames, self.temppdb, chunksize=10000)
        output = reader.get_output()
        params = {'dim': self.dim, 'lag': 1001}
        ref = api.tica(reader, **params)
        partial = api.tica(**params)

        # feeding the trajectories one by one has to reproduce the one-shot estimate
        for traj in output:
            partial.partial_fit(traj)

        np.testing.assert_allclose(partial.eigenvalues, ref.eigenvalues)
        # only compare first two eigenvectors, because we only have two metastable processes
        np.testing.assert_allclose(np.abs(partial.eigenvectors[:2]),
                                   np.abs(ref.eigenvectors[:2]), rtol=1e-3, atol=1e-3)


if __name__ == "__main__":
    unittest.main()
class TestStride(unittest.TestCase):
    """Checks that strided output of readers and TICA has the expected length and content."""

    @classmethod
    def setUpClass(cls):
        cls.dim = 3  # dimension (must be divisible by 3)
        N_trajs = 10  # number of trajectories

        # create topology file; mkstemp creates the file atomically, whereas
        # the deprecated mktemp only returns a name that may be taken by
        # someone else before it is opened
        fd, cls.temppdb = tempfile.mkstemp(suffix='.pdb')
        with os.fdopen(fd, 'w') as f:
            for i in range(cls.dim//3):
                # NOTE(review): PDB ATOM records are column-positional - confirm
                # the spacing of this template against the repository version.
                print(('ATOM %5d C ACE A 1 28.490 31.600 33.379 0.00 1.00' % i), file=f)

        cls.trajnames = []  # list of xtc file names
        cls.data = []
        for i in range(N_trajs):
            # set up data: random length, normally distributed coordinates
            N = int(np.random.rand()*1000+1000)
            xyz = np.random.randn(N, cls.dim//3, 3).astype(np.float32)
            cls.data.append(xyz)
            t = np.arange(0, N)
            # create trajectory file
            traj = mdtraj.load(cls.temppdb)
            traj.xyz = xyz
            traj.time = t
            fd, tempfname = tempfile.mkstemp(suffix='.xtc')
            os.close(fd)  # only the reserved name is needed, the writer reopens the file
            traj.save(tempfname)
            cls.trajnames.append(tempfname)

    def test_length_and_content_feature_reader_and_TICA(self):
        for stride in range(1, 100, 23):
            r = coor.source(self.trajnames, top=self.temppdb)
            t = coor.tica(data=r, lag=2, dim=2, force_eigenvalues_le_one=True)
            t.parametrize()

            # subsample data
            out_tica = t.get_output(stride=stride)
            out_reader = r.get_output(stride=stride)

            # get length in different ways
            len_tica = [x.shape[0] for x in out_tica]
            len_reader = [x.shape[0] for x in out_reader]
            len_trajs = t.trajectory_lengths(stride=stride)
            # number of frames visited by range(0, n, stride)
            len_ref = [(x.shape[0]-1)//stride+1 for x in self.data]

            # compare length (assertEqual reports both lists on failure)
            np.testing.assert_equal(len_trajs, len_ref)
            self.assertEqual(len_ref, len_tica)
            self.assertEqual(len_ref, len_reader)

            # compare content (reader)
            for ref_data, test_data in zip(self.data, out_reader):
                ref_data_reshaped = ref_data.reshape((ref_data.shape[0], ref_data.shape[1]*3))
                self.assertTrue(np.allclose(ref_data_reshaped[::stride, :], test_data, atol=1E-3))

    def test_content_data_in_memory(self):
        # prepare test data
        N_trajs = 10
        d = []
        for _ in range(N_trajs):
            N = int(np.random.rand()*1000+10)
            d.append(np.random.randn(N, 10).astype(np.float32))

        # read data
        reader = coor.source(d)

        # compare
        for stride in range(1, 10, 3):
            out_reader = reader.get_output(stride=stride)
            for ref_data, test_data in zip(d, out_reader):
                self.assertTrue(np.all(ref_data[::stride] == test_data))  # here we can test exact equality

    def test_parametrize_with_stride(self):
        for stride in range(1, 100, 23):
            r = coor.source(self.trajnames, top=self.temppdb)
            tau = 5
            try:
                t = coor.tica(r, lag=tau, stride=stride, dim=2, force_eigenvalues_le_one=True)
                # force_eigenvalues_le_one=True enables an internal consistency check in TICA
                t.parametrize(stride=stride)
                self.assertTrue(np.all(t.eigenvalues <= 1.0+1.E-12))
            except RuntimeError:
                # a RuntimeError is only acceptable if the lag is no multiple of
                # the stride; a unittest assertion is used because plain assert
                # statements are removed under "python -O"
                self.assertNotEqual(tau % stride, 0)

    @classmethod
    def tearDownClass(cls):
        for fname in cls.trajnames:
            os.unlink(fname)
        os.unlink(cls.temppdb)
        super(TestStride, cls).tearDownClass()


if __name__ == "__main__":
    unittest.main()
class RandomDataSource(DataInMemory):
    """In-memory data source holding one freshly drawn random sample matrix."""

    def __init__(self, a=None, b=None, chunksize=100, n_samples=1000, dim=3):
        """
        Draw an (n_samples, dim) array with np.random.random; if both bounds
        are given, the samples are scaled by (b - a) and shifted by a.
        """
        samples = np.random.random((n_samples, dim))
        has_bounds = not (a is None or b is None)
        if has_bounds:
            samples *= (b - a)
            samples += a
        super(RandomDataSource, self).__init__(samples, chunksize=chunksize)


class TestRegSpaceClustering(unittest.TestCase):
    """Tests for regular space clustering on random in-memory data."""

    def setUp(self):
        # every test starts from a fresh estimator fed by the default source
        self.dmin = 0.3
        self.clustering = RegularSpaceClustering(dmin=self.dmin)
        self.clustering.data_producer = RandomDataSource()

    def _use_spread_data(self):
        # shared setup: data spread over [-2, 2) together with dmin = 2
        self.clustering.data_producer = RandomDataSource(a=-2, b=2)
        self.clustering.dmin = 2

    def test_algorithm(self):
        self.clustering.parametrize()

        # correct type of dtrajs
        assert types.is_int_vector(self.clustering.dtrajs[0])

        # every pair of distinct centroids has to be at least dmin apart
        centers = self.clustering.clustercenters
        for first, second in itertools.combinations(centers, 2):
            if np.allclose(first, second):  # skip equal pairs
                continue

            dist = np.linalg.norm(first - second, 2)
            message = ("centroid pair\n%s\n%s\n has smaller"
                       " distance than dmin(%f): %f"
                       % (first, second, self.dmin, dist))
            self.assertGreaterEqual(dist, self.dmin, message)

    def test_assignment(self):
        self.clustering.parametrize()

        assert len(self.clustering.clustercenters) > 1

        # the number of visited states has to match the number of centers
        n_states = len(np.unique(self.clustering.dtrajs))
        n_centers = len(self.clustering.clustercenters)
        self.assertEqual(n_states, n_centers,
                         "number of unique states in dtrajs"
                         " should be equal.")

        new_data = np.random.random((1000, 3))
        self.clustering.assign(new_data, stride=1)

    def test_spread_data(self):
        self._use_spread_data()
        self.clustering.parametrize()

    def test1d_data(self):
        one_dimensional = np.random.random(100)
        cluster_regspace(one_dimensional, dmin=0.3)

    def test_non_existent_metric(self):
        self._use_spread_data()
        self.clustering.metric = "non_existent_metric"
        with self.assertRaises(ValueError):
            self.clustering.parametrize()

    def test_minRMSD_metric(self):
        self._use_spread_data()
        self.clustering.metric = "minRMSD"
        self.clustering.parametrize()

        new_data = np.random.random((1000, 3))
        self.clustering.assign(new_data, stride=1)

    def test_too_small_dmin_should_warn(self):
        import warnings

        max_centers = 50
        self.clustering.dmin = 1e-8
        self.clustering.max_centers = max_centers
        with warnings.catch_warnings(record=True) as caught:
            # make sure no warning is filtered out
            warnings.simplefilter("always")
            # estimation is expected to emit exactly one warning
            self.clustering.estimate(self.clustering.data_producer)
            assert caught
            assert len(caught) == 1

            assert len(self.clustering.clustercenters) == max_centers

            # assign data
            out = self.clustering.get_output()
            assert len(out) == self.clustering.number_of_trajectories()
            assert len(out[0]) == self.clustering.trajectory_lengths()[0]
18 | 19 | ''' 20 | Created on 19.01.2015 21 | 22 | @author: marscher 23 | ''' 24 | 25 | from __future__ import absolute_import 26 | import os 27 | import tempfile 28 | import unittest 29 | import mdtraj 30 | import numpy as np 31 | 32 | from mdtraj.core.trajectory import Trajectory 33 | from mdtraj.core.element import hydrogen, oxygen 34 | from mdtraj.core.topology import Topology 35 | 36 | from chainsaw.clustering.uniform_time import UniformTimeClustering 37 | from chainsaw.pipelines import Discretizer 38 | from chainsaw.data.data_in_memory import DataInMemory 39 | from chainsaw.api import cluster_kmeans, pca, source 40 | from six.moves import range 41 | 42 | 43 | def create_water_topology_on_disc(n): 44 | topfile = tempfile.mktemp('.pdb') 45 | top = Topology() 46 | chain = top.add_chain() 47 | 48 | for i in range(n): 49 | res = top.add_residue('r%i' % i, chain) 50 | h1 = top.add_atom('H', hydrogen, res) 51 | o = top.add_atom('O', oxygen, res) 52 | h2 = top.add_atom('H', hydrogen, res) 53 | top.add_bond(h1, o) 54 | top.add_bond(h2, o) 55 | 56 | xyz = np.zeros((n * 3, 3)) 57 | Trajectory(xyz, top).save_pdb(topfile) 58 | return topfile 59 | 60 | 61 | def create_traj_on_disc(topfile, n_frames, n_atoms): 62 | fn = tempfile.mktemp('.xtc') 63 | xyz = np.random.random((n_frames, n_atoms, 3)) 64 | t = mdtraj.load(topfile) 65 | t.xyz = xyz 66 | t.time = np.arange(n_frames) 67 | t.save(fn) 68 | return fn 69 | 70 | 71 | class TestDiscretizer(unittest.TestCase): 72 | 73 | @classmethod 74 | def setUpClass(cls): 75 | c = super(TestDiscretizer, cls).setUpClass() 76 | # create a fake trajectory which has 2 atoms and coordinates are just a range 77 | # over all frames. 
78 | cls.n_frames = 1000 79 | cls.n_residues = 30 80 | cls.topfile = create_water_topology_on_disc(cls.n_residues) 81 | 82 | # create some trajectories 83 | t1 = create_traj_on_disc( 84 | cls.topfile, cls.n_frames, cls.n_residues * 3) 85 | 86 | t2 = create_traj_on_disc( 87 | cls.topfile, cls.n_frames, cls.n_residues * 3) 88 | 89 | cls.trajfiles = [t1, t2] 90 | 91 | cls.dest_dir = tempfile.mkdtemp() 92 | 93 | return c 94 | 95 | @classmethod 96 | def tearDownClass(cls): 97 | """delete temporary files""" 98 | os.unlink(cls.topfile) 99 | for f in cls.trajfiles: 100 | os.unlink(f) 101 | 102 | import shutil 103 | shutil.rmtree(cls.dest_dir, ignore_errors=True) 104 | 105 | def test(self): 106 | reader = source(self.trajfiles, top=self.topfile) 107 | pcat = pca(dim=2) 108 | 109 | n_clusters = 2 110 | clustering = UniformTimeClustering(n_clusters=n_clusters) 111 | 112 | D = Discretizer(reader, transform=pcat, cluster=clustering) 113 | D.parametrize() 114 | 115 | self.assertEqual(len(D.dtrajs), len(self.trajfiles)) 116 | 117 | for dtraj in clustering.dtrajs: 118 | unique = np.unique(dtraj) 119 | self.assertEqual(unique.shape[0], n_clusters) 120 | 121 | def test_with_data_in_mem(self): 122 | import chainsaw as api 123 | 124 | data = [np.random.random((100, 50)), 125 | np.random.random((103, 50)), 126 | np.random.random((33, 50))] 127 | reader = source(data) 128 | assert isinstance(reader, DataInMemory) 129 | 130 | tpca = api.pca(dim=2) 131 | 132 | n_centers = 10 133 | km = api.cluster_kmeans(k=n_centers) 134 | 135 | disc = api.discretizer(reader, tpca, km) 136 | disc.parametrize() 137 | 138 | dtrajs = disc.dtrajs 139 | for dtraj in dtrajs: 140 | n_states = np.max((np.unique(dtraj))) 141 | self.assertGreaterEqual(n_centers - 1, n_states, 142 | "dtraj has more states than cluster centers") 143 | 144 | def test_save_dtrajs(self): 145 | reader = source(self.trajfiles, top=self.topfile) 146 | cluster = cluster_kmeans(k=2) 147 | d = Discretizer(reader, cluster=cluster) 148 | 
def acf(trajs, stride=1, max_lag=None, subtract_mean=True, normalize=True, mean=None):
    '''Computes the (combined) autocorrelation function of multiple trajectories.

    Parameters
    ----------
    trajs : list of (*,N) ndarrays
        the observable trajectories, N is the number of observables.
        A single ndarray is treated as a list with one trajectory and
        one-dimensional trajectories are treated as a single observable.
        The input arrays are never modified.
    stride : int (default = 1)
        only take every n'th frame from trajs
    max_lag : int (default = maximum trajectory length / stride)
        only compute acf up to this lag time
    subtract_mean : bool (default = True)
        subtract trajectory mean before computing acfs
    normalize : bool (default = True)
        divide acf by the variance such that acf[0,:]==1
    mean : (N) ndarray (optional)
        if subtract_mean is True, you can give the trajectory mean
        so this function doesn't have to compute it again

    Returns
    -------
    acf : (max_lag,N) ndarray
        autocorrelation functions for all observables

    Raises
    ------
    ValueError
        if no trajectory is given, a trajectory is empty after striding,
        a trajectory has more than two dimensions, or the trajectories
        disagree in their number of observables.

    Note
    ----
    The computation uses FFT (with zero-padding) and is done in memory (RAM).
    '''
    if not isinstance(trajs, list):
        trajs = [trajs]
    if not trajs:
        raise ValueError('at least one trajectory is required')

    # bring every trajectory into (T, N) shape
    shaped = []
    for i, traj in enumerate(trajs):
        if traj.ndim == 1:
            shaped.append(traj.reshape((traj.shape[0], 1)))
        elif traj.ndim == 2:
            shaped.append(traj)
        else:
            raise ValueError(
                'Unexpected number of dimensions in trajectory number %d' % i)
    trajs = shaped

    assert stride > 0, 'stride must be > 0'
    assert max_lag is None or max_lag > 0, 'max_lag must be > 0'

    # all trajectories have to agree in the number of observables; checking
    # both directions up front avoids obscure broadcasting errors later on
    n_obs = trajs[0].shape[1]
    for i, traj in enumerate(trajs):
        if traj.shape[1] != n_obs:
            raise ValueError(('number of order parameters in trajectory number %d differs '
                              'from the number found in previous trajectories.') % i)

    if subtract_mean and mean is None:
        # compute mean over all trajectories (all frames, independent of stride)
        total = trajs[0].sum(axis=0)
        n_samples = trajs[0].shape[0]
        for traj in trajs[1:]:
            total = total + traj.sum(axis=0)
            n_samples += traj.shape[0]
        # true division, also for integer data and without __future__.division
        mean = np.true_divide(total, n_samples)

    res = None  # accumulated (unnormalized) acf, shape (n_lags, n_obs)
    N = None    # number of samples contributing to every lag

    for i, traj in enumerate(trajs):
        data = traj[::stride]
        if subtract_mean:
            # out-of-place: `data` is a view, `-=` would alter the caller's array
            data = data - mean
        l = data.shape[0]
        if l == 0:
            raise ValueError('trajectory number %d is empty' % i)
        # calc acfs; zero-pad to at least 2*l-1 so the circular correlation
        # computed by the FFT does not wrap around
        n_fft = 2 ** int(np.ceil(np.log2(l * 2 - 1)))
        fft = np.fft.fft(data, n=n_fft, axis=0)
        acftraj = np.fft.ifft(fft * np.conjugate(fft), axis=0).real
        # throw away acf data for long lag times (and negative lag times)
        if max_lag and max_lag < l:
            acftraj = acftraj[:max_lag, :]
        else:
            acftraj = acftraj[:l, :]
            if max_lag and max_lag > l:
                sys.stderr.write(
                    'Warning: trajectory number %d is shorter than maximum lag.\n' % i)
        n_lags = acftraj.shape[0]
        # find number of samples used for every lag: l, l-1, ..., l-n_lags+1
        Ntraj = np.linspace(l, l - n_lags + 1, n_lags)
        if res is None:
            res = np.zeros(acftraj.shape)
            N = np.zeros(n_lags)
        elif n_lags > res.shape[0]:
            # grow the temporal dimension with zeros (np.resize would fill
            # the new rows with repeated copies of the old data)
            grow = n_lags - res.shape[0]
            res = np.concatenate((res, np.zeros((grow, n_obs))), axis=0)
            N = np.concatenate((N, np.zeros(grow)))
        # update acf and number of samples
        res[:n_lags, :] += acftraj
        N[:n_lags] += Ntraj

    # divide by number of samples
    res = res / N[:, np.newaxis]

    # normalize acfs
    if normalize:
        res /= res[0, :].copy()

    return res
class TimeUnit(object):
    """A physical time step: a numeric factor combined with a time unit.

    The unit is either the abstract 'step' or one of fs, ps, ns, us, ms, s.
    """

    # internal unit codes; the non-negative codes index into _unit_names
    _UNIT_STEP = -1
    _UNIT_FS = 0
    _UNIT_PS = 1
    _UNIT_NS = 2
    _UNIT_US = 3
    _UNIT_MS = 4
    _UNIT_S = 5
    _unit_names = ['fs','ps','ns','us','ms','s']
    # (abbreviation, long-name prefix, unit code) used when parsing strings
    _unit_aliases = (
        ('fs', 'femtosecond', _UNIT_FS),
        ('ps', 'picosecond', _UNIT_PS),
        ('ns', 'nanosecond', _UNIT_NS),
        ('us', 'microsecond', _UNIT_US),
        ('ms', 'millisecond', _UNIT_MS),
        ('s', 'second', _UNIT_S),
    )

    def __init__(self, unit = '1 step'):
        """
        Initializes the time unit object

        Parameters
        ----------
        unit : str or TimeUnit
            Description of a physical time unit. By default '1 step', i.e. there is no physical time unit.
            Specify by a number, whitespace and unit. The number is optional and defaults to 1.
            Permitted units are (* is an arbitrary string):
            'step', 'steps'
            'fs',  'femtosecond*'
            'ps',  'picosecond*'
            'ns',  'nanosecond*'
            'us',  'microsecond*'
            'ms',  'millisecond*'
            's',   'second*'
            If a TimeUnit is given, it is copied.

        Raises
        ------
        ValueError
            if the string has more than two words, the number cannot be
            parsed, or the unit is not one of the permitted units.
        """
        if isinstance(unit, TimeUnit):  # copy constructor
            self._factor = unit._factor
            self._unit = unit._unit
            return

        # construct from string; split() tolerates repeated/odd whitespace
        words = unit.lower().split()
        if len(words) == 1:
            self._factor = 1.0
            unitstring = words[0]
        elif len(words) == 2:
            self._factor = float(words[0])
            unitstring = words[1]
        else:
            raise ValueError('Illegal input string: '+str(unit))

        if unitstring in ('step', 'steps'):
            self._unit = self._UNIT_STEP
            return
        for short, prefix, code in self._unit_aliases:
            if unitstring == short or unitstring.startswith(prefix):
                self._unit = code
                return
        raise ValueError('Time unit is not understood: '+unit)

    def __str__(self):
        if self._unit == self._UNIT_STEP:
            return str(self._factor)+' step'
        else:
            return str(self._factor)+' '+self._unit_names[self._unit]

    @property
    def dt(self):
        """ The numeric factor in front of the unit """
        return self._factor

    @property
    def unit(self):
        """ The internal integer code of the unit (-1 for 'step') """
        return self._unit

    def get_scaled(self, factor):
        """ Get a new time unit, scaled by the given factor """
        import copy
        res = copy.deepcopy(self)
        res._factor *= factor
        return res

    def rescale_around1(self, times):
        """
        Suggests a rescaling factor and new physical time unit to balance the given time multiples around 1.

        Parameters
        ----------
        times : float array
            array of times in multiple of the present elementary unit

        Returns
        -------
        times : float array
            the (possibly rescaled) times
        unit : str
            name of the unit the returned times are expressed in

        """
        if self._unit == self._UNIT_STEP:
            return times, 'step' # nothing to do

        m = np.mean(times)
        mult = 1.0
        cur_unit = self._unit

        # numbers are too small. Making them larger and reducing the unit,
        # but never below the smallest known unit (fs):
        if (m < 0.001):
            while mult*m < 0.001 and cur_unit > self._UNIT_FS:
                mult *= 1000
                cur_unit -= 1
            return mult*np.asarray(times), self._unit_names[cur_unit]

        # numbers are too large. Making them smaller and increasing the unit,
        # but never above the largest known unit (s):
        if (m > 1000):
            while mult*m > 1000 and cur_unit < self._UNIT_S:
                mult /= 1000
                cur_unit += 1
            return mult*np.asarray(times), self._unit_names[cur_unit]

        # nothing to do; report the unit by name like the branches above
        return times, self._unit_names[self._unit]


def bytes_to_string(num, suffix='B'):
    """
    Returns the size of num (bytes) in a human readable form up to Yottabytes (YB).
    Sizes beyond that are expressed in Yottabytes, sizes below one byte in bytes.
    :param num: The size of interest in bytes.
    :param suffix: A suffix, default 'B' for 'bytes'.
    :return: a human readable representation of a size in bytes
    """
    extensions = ["%s%s" % (x, suffix) for x in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y']]
    if num == 0:
        return "0%s" % extensions[0]

    n_bytes = float(abs(num))
    place = 0
    # repeated division instead of a floating point logarithm: no rounding
    # surprises at exact powers of 1024 and the index always stays in range
    while n_bytes >= 1024 and place < len(extensions) - 1:
        n_bytes /= 1024.0
        place += 1
    if num < 0:
        n_bytes = -n_bytes
    return "%.1f%s" % (n_bytes, extensions[place])
it4._n_chunks == 3 * 4 51 | 52 | # test for lagged iterator 53 | for stride in range(1, 5): 54 | for lag in range(0, 18): 55 | it = r.iterator( 56 | lag=lag, chunk=30, stride=stride, return_trajindex=False) 57 | chunks = 0 58 | for _ in it: 59 | chunks += 1 60 | assert chunks == it._n_chunks 61 | 62 | def test_skip(self): 63 | r = DataInMemory(self.d) 64 | lagged_it = r.iterator(lag=5) 65 | assert lagged_it._it.skip == 0 66 | assert lagged_it._it_lagged.skip == 5 67 | 68 | it = r.iterator() 69 | for itraj, X in it: 70 | if itraj == 0: 71 | it.skip = 5 72 | if itraj == 1: 73 | assert it.skip == 5 74 | 75 | def test_chunksize(self): 76 | r = DataInMemory(self.d) 77 | cs = np.arange(1, 17) 78 | i = 0 79 | it = r.iterator(chunk=cs[i]) 80 | for itraj, X in it: 81 | if not it.last_chunk_in_traj: 82 | assert len(X) == it.chunksize 83 | else: 84 | assert len(X) <= it.chunksize 85 | i += 1 86 | i %= len(cs) 87 | it.chunksize = cs[i] 88 | assert it.chunksize == cs[i] 89 | 90 | def test_last_chunk(self): 91 | r = DataInMemory(self.d) 92 | it = r.iterator(chunk=0) 93 | for itraj, X in it: 94 | assert it.last_chunk_in_traj 95 | if itraj == 2: 96 | assert it.last_chunk 97 | 98 | def test_stride(self): 99 | r = DataInMemory(self.d) 100 | stride = np.arange(1, 17) 101 | i = 0 102 | it = r.iterator(stride=stride[i], chunk=1) 103 | for _ in it: 104 | i += 1 105 | i %= len(stride) 106 | it.stride = stride[i] 107 | assert it.stride == stride[i] 108 | 109 | def test_return_trajindex(self): 110 | r = DataInMemory(self.d) 111 | it = r.iterator(chunk=0) 112 | it.return_traj_index = True 113 | assert it.return_traj_index is True 114 | for tup in it: 115 | assert len(tup) == 2 116 | it.reset() 117 | it.return_traj_index = False 118 | assert it.return_traj_index is False 119 | itraj = 0 120 | for tup in it: 121 | np.testing.assert_equal(tup, self.d[itraj]) 122 | itraj += 1 123 | 124 | for tup in r.iterator(return_trajindex=True): 125 | assert len(tup) == 2 126 | itraj = 0 127 | for tup in 
r.iterator(return_trajindex=False): 128 | np.testing.assert_equal(tup, self.d[itraj]) 129 | itraj += 1 130 | 131 | def test_pos(self): 132 | r = DataInMemory(self.d) 133 | r.chunksize = 17 134 | it = r.iterator() 135 | t = 0 136 | for itraj, X in it: 137 | assert t == it.pos 138 | t += len(X) 139 | if it.last_chunk_in_traj: 140 | t = 0 141 | 142 | def test_write_to_csv_propagate_filenames(self): 143 | from chainsaw import source, tica 144 | with TemporaryDirectory() as td: 145 | data = [np.random.random((20, 3))] * 3 146 | fns = [os.path.join(td, f) 147 | for f in ('blah.npy', 'blub.npy', 'foo.npy')] 148 | for x, fn in zip(data, fns): 149 | np.save(fn, x) 150 | reader = source(fns) 151 | assert reader.filenames == fns 152 | tica_obj = tica(reader, lag=1, dim=2) 153 | tica_obj.write_to_csv(extension=".txt", chunksize=3) 154 | res = sorted([os.path.abspath(x) for x in glob(td + os.path.sep + '*.txt')]) 155 | self.assertEqual(len(res), len(fns)) 156 | desired_fns = sorted([s.replace('.npy', '.txt') for s in fns]) 157 | self.assertEqual(res, desired_fns) 158 | 159 | # compare written results 160 | expected = tica_obj.get_output() 161 | actual = source(list(s.replace('.npy', '.txt') for s in fns)).get_output() 162 | assert len(actual) == len(fns) 163 | for a, e in zip(actual, expected): 164 | np.testing.assert_allclose(a, e) 165 | 166 | if __name__ == '__main__': 167 | unittest.main() 168 | -------------------------------------------------------------------------------- /chainsaw/util/reflection.py: -------------------------------------------------------------------------------- 1 | 2 | # This file is part of PyEMMA. 
# Replacement for inspect.getargspec(), which is deprecated in python 3.5.
# The version below is borrowed from Django,
# https://github.com/django/django/pull/4846
#
# inspect.getargspec(func) lists `self` for bound methods, while
# inspect.signature(func) does not. `getargspec_no_self` is the common
# ground: it mimics inspect.getargspec but never lists `self`, so callers
# do not need to know which of the two is used underneath.

try:
    # is it python 3.3 or higher?
    inspect.signature

    # Apparently, yes. Wrap inspect.signature

    ArgSpec = namedtuple('ArgSpec', ['args', 'varargs', 'keywords', 'defaults'])

    def getargspec_no_self(func):
        """inspect.getargspec replacement using inspect.signature.

        inspect.getargspec is deprecated in python 3. This is a replacement
        based on the (new in python 3.3) `inspect.signature`.

        Parameters
        ----------
        func : callable
            A callable to inspect

        Returns
        -------
        argspec : ArgSpec(args, varargs, varkw, defaults)
            This is similar to the result of inspect.getargspec(func) under
            python 2.x.
            NOTE: if the first argument of `func` is self, it is *not*, I repeat
            *not* included in argspec.args.
            This is done for consistency between inspect.getargspec() under
            python 2.x, and inspect.signature() under python 3.x.
        """
        params = list(inspect.signature(func).parameters.values())
        positional = [
            p for p in params
            if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
        ]
        args = [p.name for p in positional]
        varargs = [
            p.name for p in params
            if p.kind == inspect.Parameter.VAR_POSITIONAL
        ]
        varargs = varargs[0] if varargs else None
        varkw = [
            p.name for p in params
            if p.kind == inspect.Parameter.VAR_KEYWORD
        ]
        varkw = varkw[0] if varkw else None
        defaults = [
            p.default for p in positional if p.default is not p.empty
        ] or None

        # callables without positional parameters have nothing to strip
        if args and args[0] == 'self':
            args.pop(0)

        return ArgSpec(args, varargs, varkw, defaults)

except AttributeError:
    # python 2.x
    def getargspec_no_self(func):
        """inspect.getargspec replacement for compatibility with python 3.x.

        inspect.getargspec is deprecated in python 3. This wraps it, and
        *removes* `self` from the argument list of `func`, if present.
        This is done for forward compatibility with python 3.

        Parameters
        ----------
        func : callable
            A callable to inspect

        Returns
        -------
        argspec : ArgSpec(args, varargs, varkw, defaults)
            This is similar to the result of inspect.getargspec(func) under
            python 2.x.
            NOTE: if the first argument of `func` is self, it is *not*, I repeat
            *not* included in argspec.args.
            This is done for consistency between inspect.getargspec() under
            python 2.x, and inspect.signature() under python 3.x.
        """
        argspec = inspect.getargspec(func)
        # callables without positional parameters have nothing to strip
        if argspec.args and argspec.args[0] == 'self':
            argspec.args.pop(0)
        return argspec


def call_member(obj, f, *args, **kwargs):
    """ Calls the specified method, property or attribute of the given object

    Parameters
    ----------
    obj : object
        The object that will be used
    f : str or function
        Name of or reference to method, property or attribute
    *args, **kwargs
        Passed on to the member if it is a method; ignored otherwise.

    Returns
    -------
    The return value of the method call if the member is a method, otherwise
    the value of the attribute or property itself.

    Raises
    ------
    AttributeError
        if `obj` has no member of the given name.
    """
    # get function name
    if not isinstance(f, six.string_types):
        # bound (and python 2 unbound) methods wrap the function in __func__,
        # plain function references carry __name__ themselves
        fname = getattr(f, '__func__', f).__name__
    else:
        fname = f
    # get the method ref
    method = getattr(obj, fname)
    # handle cases
    if inspect.ismethod(method):
        return method(*args, **kwargs)

    # attribute or property
    return method


def get_default_args(func):
    """
    returns a dictionary of arg_name:default_values for the input function

    Arguments without a default value are not contained; a function without
    any defaults yields an empty dictionary.
    """
    args, _varargs, _keywords, defaults = getargspec_no_self(func)
    if not defaults:
        return {}
    # defaults always belong to the trailing positional arguments
    return dict(zip(args[-len(defaults):], defaults))
22 | cov(ic_i,ic_j) = delta_ij and cov(ic_i,ic_j,tau) = lambda_i delta_ij 23 | @author: Fabian Paul 24 | ''' 25 | 26 | from __future__ import absolute_import 27 | from __future__ import print_function 28 | 29 | import os 30 | import tempfile 31 | import unittest 32 | from logging import getLogger 33 | 34 | import mdtraj 35 | import numpy as np 36 | from chainsaw.api import tica 37 | from chainsaw.data.md.feature_reader import FeatureReader 38 | from chainsaw.util.contexts import numpy_random_seed 39 | from nose.plugins.attrib import attr 40 | from six.moves import range 41 | 42 | log = getLogger('chainsaw.'+'TestFeatureReaderAndTICAProjection') 43 | 44 | 45 | def random_invertible(n, eps=0.01): 46 | 'generate real random invertible matrix' 47 | m = np.random.randn(n, n) 48 | u, s, v = np.linalg.svd(m) 49 | s = np.maximum(s, eps) 50 | return u.dot(np.diag(s)).dot(v) 51 | 52 | 53 | @attr(slow=True) 54 | class TestFeatureReaderAndTICAProjection(unittest.TestCase): 55 | @classmethod 56 | def setUpClass(cls): 57 | with numpy_random_seed(52): 58 | c = super(TestFeatureReaderAndTICAProjection, cls).setUpClass() 59 | 60 | cls.dim = 99 # dimension (must be divisible by 3) 61 | N = 5000 # length of single trajectory # 500000 # 50000 62 | N_trajs = 10 # number of trajectories 63 | 64 | A = random_invertible(cls.dim) # mixing matrix 65 | # tica will approximate its inverse with the projection matrix 66 | mean = np.random.randn(cls.dim) 67 | 68 | # create topology file 69 | cls.temppdb = tempfile.mktemp('.pdb') 70 | with open(cls.temppdb, 'w') as f: 71 | for i in range(cls.dim // 3): 72 | print(('ATOM %5d C ACE A 1 28.490 31.600 33.379 0.00 1.00' % i), file=f) 73 | 74 | t = np.arange(0, N) 75 | cls.trajnames = [] # list of xtc file names 76 | for i in range(N_trajs): 77 | # set up data 78 | white = np.random.randn(N, cls.dim) 79 | brown = np.cumsum(white, axis=0) 80 | correlated = np.dot(brown, A) 81 | data = correlated + mean 82 | xyz = data.reshape((N, cls.dim // 3, 3)) 83 | 
# create trajectory file 84 | traj = mdtraj.load(cls.temppdb) 85 | traj.xyz = xyz 86 | traj.time = t 87 | tempfname = tempfile.mktemp('.xtc') 88 | traj.save(tempfname) 89 | cls.trajnames.append(tempfname) 90 | 91 | @classmethod 92 | def tearDownClass(cls): 93 | for fname in cls.trajnames: 94 | os.unlink(fname) 95 | os.unlink(cls.temppdb) 96 | super(TestFeatureReaderAndTICAProjection, cls).tearDownClass() 97 | 98 | def test_covariances_and_eigenvalues(self): 99 | reader = FeatureReader(self.trajnames, self.temppdb) 100 | for tau in [1, 10, 100, 1000, 2000]: 101 | trans = tica(lag=tau, dim=self.dim, kinetic_map=False) 102 | trans.data_producer = reader 103 | 104 | log.info('number of trajectories reported by tica %d' % trans.number_of_trajectories()) 105 | trans.parametrize() 106 | data = trans.get_output() 107 | 108 | log.info('max. eigenvalue: %f' % np.max(trans.eigenvalues)) 109 | self.assertTrue(np.all(trans.eigenvalues <= 1.0)) 110 | # check ICs 111 | check = tica(data=data, lag=tau, dim=self.dim) 112 | 113 | np.testing.assert_allclose(np.eye(self.dim), check.cov, atol=1e-8) 114 | np.testing.assert_allclose(check.mean, 0.0, atol=1e-8) 115 | ic_cov_tau = np.zeros((self.dim, self.dim)) 116 | ic_cov_tau[np.diag_indices(self.dim)] = trans.eigenvalues 117 | np.testing.assert_allclose(ic_cov_tau, check.cov_tau, atol=1e-8) 118 | 119 | def test_partial_fit(self): 120 | from chainsaw import source 121 | reader = source(self.trajnames, top=self.temppdb) 122 | reader_output = reader.get_output() 123 | 124 | params = {'lag': 10, 'kinetic_map': False, 'dim': self.dim} 125 | 126 | tica_obj = tica(**params) 127 | tica_obj.partial_fit(reader_output[0]) 128 | assert not tica_obj._estimated 129 | # acccess eigenvectors to force diagonalization 130 | tica_obj.eigenvectors 131 | assert tica_obj._estimated 132 | 133 | tica_obj.partial_fit(reader_output[1]) 134 | assert not tica_obj._estimated 135 | 136 | tica_obj.eigenvalues 137 | assert tica_obj._estimated 138 | 139 | for traj in 
reader_output[2:]: 140 | tica_obj.partial_fit(traj) 141 | 142 | # reference 143 | ref = tica(reader, **params) 144 | 145 | np.testing.assert_allclose(tica_obj.cov, ref.cov, atol=1e-15) 146 | np.testing.assert_allclose(tica_obj.cov_tau, ref.cov_tau, atol=1e-15) 147 | 148 | np.testing.assert_allclose(tica_obj.eigenvalues, ref.eigenvalues, atol=1e-15) 149 | # we do not test eigenvectors here, since the system is very metastable and 150 | # we have multiple eigenvalues very close to one. 151 | 152 | if __name__ == "__main__": 153 | unittest.main() 154 | -------------------------------------------------------------------------------- /chainsaw/tests/test_cluster.py: -------------------------------------------------------------------------------- 1 | 2 | # This file is part of PyEMMA. 3 | # 4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER) 5 | # 6 | # PyEMMA is free software: you can redistribute it and/or modify 7 | # it under the terms of the GNU Lesser General Public License as published by 8 | # the Free Software Foundation, either version 3 of the License, or 9 | # (at your option) any later version. 10 | # 11 | # This program is distributed in the hope that it will be useful, 12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 | # GNU General Public License for more details. 15 | # 16 | # You should have received a copy of the GNU Lesser General Public License 17 | # along with this program. If not, see . 
class TestCluster(unittest.TestCase):
    """API smoke tests shared by the k-means, regular-space and uniform-time
    clusterings, each estimated once on the same in-memory 2D trajectory."""

    @classmethod
    def setUpClass(cls):
        super(TestCluster, cls).setUpClass()
        cls.dtraj_dir = tempfile.mkdtemp()

        # generate Gaussian mixture
        means = [np.array([-3, 0]),
                 np.array([-1, 1]),
                 np.array([0, 0]),
                 np.array([1, -1]),
                 np.array([4, 2])]
        widths = [np.array([0.3, 2]) for _ in means]
        # continuous trajectory
        nsample = 1000
        cls.T = len(means) * nsample
        cls.X = np.zeros((cls.T, 2))
        for i, (mean, width) in enumerate(zip(means, widths)):
            frames = slice(i * nsample, (i + 1) * nsample)
            # NOTE(review): randn() without a size argument draws a single
            # scalar, so all nsample frames of one component are identical.
            # If a real mixture is intended this should probably be
            # randn(nsample); left unchanged because it alters the test data.
            cls.X[frames, 0] = width[0] * np.random.randn() + mean[0]
            cls.X[frames, 1] = width[1] * np.random.randn() + mean[1]
        # cluster in different ways
        cls.km = coor.cluster_kmeans(data=cls.X, k=100)
        cls.rs = coor.cluster_regspace(data=cls.X, dmin=0.5)
        cls.rt = coor.cluster_uniform_time(data=cls.X, k=100)
        cls.cl = [cls.km, cls.rs, cls.rt]

    @classmethod
    def tearDownClass(cls):
        # remove the directory test_save_dtrajs writes into; it used to leak
        import shutil
        shutil.rmtree(cls.dtraj_dir, ignore_errors=True)
        super(TestCluster, cls).tearDownClass()

    def test_chunksize(self):
        for c in self.cl:
            assert types.is_int(c.chunksize)

    def test_clustercenters(self):
        for c in self.cl:
            assert c.clustercenters.shape[0] == c.n_clusters
            assert c.clustercenters.shape[1] == 2

    def test_data_producer(self):
        for c in self.cl:
            assert c.data_producer is not None

    def test_describe(self):
        for c in self.cl:
            desc = c.describe()
            assert types.is_string(desc) or types.is_list_of_string(desc)

    def test_dimension(self):
        for c in self.cl:
            assert types.is_int(c.dimension())
            assert c.dimension() == 1

    def test_dtrajs(self):
        for c in self.cl:
            assert len(c.dtrajs) == 1
            assert c.dtrajs[0].dtype == c.output_type()
            assert len(c.dtrajs[0]) == self.T

    def test_get_output(self):
        for c in self.cl:
            O = c.get_output()
            assert types.is_list(O)
            assert len(O) == 1
            assert types.is_int_matrix(O[0])
            assert O[0].shape[0] == self.T
            assert O[0].shape[1] == 1

    def test_in_memory(self):
        for c in self.cl:
            assert isinstance(c.in_memory, bool)

    def test_iterator(self):
        for c in self.cl:
            for itraj, chunk in c:
                assert types.is_int(itraj)
                assert types.is_int_matrix(chunk)
                assert chunk.shape[0] <= c.chunksize or c.chunksize == 0
                assert chunk.shape[1] == c.dimension()

    def test_map(self):
        for c in self.cl:
            Y = c.transform(self.X)
            assert Y.shape[0] == self.T
            assert Y.shape[1] == 1
            # test if consistent with get_output
            assert np.allclose(Y, c.get_output()[0])

    def test_n_frames_total(self):
        for c in self.cl:
            # the comparison used to be evaluated and discarded (no assert)
            assert c.n_frames_total() == self.T

    def test_number_of_trajectories(self):
        for c in self.cl:
            # the comparison used to be evaluated and discarded (no assert)
            assert c.number_of_trajectories() == 1

    def test_output_type(self):
        for c in self.cl:
            assert c.output_type() == np.int32

    def test_parametrize(self):
        for c in self.cl:
            # nothing should happen
            c.parametrize()

    def test_save_dtrajs(self):
        extension = ".dtraj"
        outdir = self.dtraj_dir
        for c in self.cl:
            prefix = "test_save_dtrajs_%s" % type(c).__name__
            c.save_dtrajs(trajfiles=None, prefix=prefix, output_dir=outdir, extension=extension)

            names = ["%s_%i%s" % (prefix, i, extension)
                     for i in range(c.data_producer.number_of_trajectories())]
            names = [os.path.join(outdir, n) for n in names]

            # check files with given patterns are there (os.stat raises if not)
            for f in names:
                os.stat(f)

    def test_trajectory_length(self):
        for c in self.cl:
            assert c.trajectory_length(0) == self.T
            with self.assertRaises(IndexError):
                c.trajectory_length(1)

    def test_trajectory_lengths(self):
        for c in self.cl:
            assert len(c.trajectory_lengths()) == 1
            assert c.trajectory_lengths()[0] == c.trajectory_length(0)
28 | """ 29 | return self._is_random_accessible and \ 30 | not isinstance(self.ra_itraj_cuboid, NotImplementedRandomAccessStrategy) and \ 31 | not isinstance(self.ra_linear, NotImplementedRandomAccessStrategy) and \ 32 | not isinstance(self.ra_itraj_jagged, NotImplementedRandomAccessStrategy) and \ 33 | not isinstance(self.ra_itraj_linear, NotImplementedRandomAccessStrategy) 34 | 35 | @property 36 | def ra_itraj_cuboid(self): 37 | """ 38 | Implementation of random access with slicing that can be up to 3-dimensional, where the first dimension 39 | corresponds to the trajectory index, the second dimension corresponds to the frames and the third dimension 40 | corresponds to the dimensions of the frames. 41 | 42 | The with the frame slice selected frames will be loaded from each in the trajectory-slice selected trajectories 43 | and then sliced with the dimension slice. For example: The data consists out of three trajectories with length 44 | 10, 20, 10, respectively. The slice `data[:, :15, :3]` returns a 3D array of shape (3, 10, 3), where the first 45 | component corresponds to the three trajectories, the second component corresponds to 10 frames (note that 46 | the last 5 frames are being truncated as the other two trajectories only have 10 frames) and the third component 47 | corresponds to the selected first three dimensions. 48 | 49 | :return: Returns an object that allows access by slices in the described manner. 50 | """ 51 | if not self._is_random_accessible: 52 | raise NotRandomAccessibleException() 53 | return self._ra_cuboid 54 | 55 | @property 56 | def ra_itraj_jagged(self): 57 | """ 58 | Behaves like ra_itraj_cuboid just that the trajectories are not truncated and returned as a list. 59 | 60 | :return: Returns an object that allows access by slices in the described manner. 
class RandomAccessStrategy(six.with_metaclass(ABCMeta)):
    """
    Abstract parent class for all random access strategies. It holds its corresponding data source and
    implements `__getitem__` as well as `__getslice__`, which both get delegated to `_handle_slice`.
    """
    def __init__(self, source, max_slice_dimension=-1):
        self._source = source
        self._max_slice_dimension = max_slice_dimension

    @abstractmethod
    def _handle_slice(self, idx):
        """Resolve the index expression ``idx``; implemented by subclasses."""
        pass

    @property
    def max_slice_dimension(self):
        """
        Property that returns how many dimensions the slice can have.
        Returns
        -------
        int : the maximal slice dimension
        """
        return self._max_slice_dimension

    def __getitem__(self, idx):
        return self._handle_slice(idx)

    def __getslice__(self, start, stop):
        """For slices of the form data[1:3]."""
        return self.__getitem__(slice(start, stop))

    def _get_indices(self, item, length):
        """Convert an index expression into an array of explicit indices.

        Parameters
        ----------
        item : slice, ndarray, list or other index
            ndarrays are returned unchanged; slices are expanded against
            ``length``; lists are converted to arrays; anything else is
            applied as an index to ``np.arange(0, length)``.
        length : int
            size of the axis that is being indexed

        Returns
        -------
        ndarray : the selected indices (a single integer is wrapped into a
            one-element array)
        """
        if isinstance(item, slice):
            # np.arange keeps an integer dtype even for an empty selection;
            # np.array(range(...)) turned an empty slice into a float64 array,
            # which cannot be used for indexing.
            item = np.arange(*item.indices(length))
        elif not isinstance(item, np.ndarray):
            if isinstance(item, list):
                # same reasoning: an empty list would otherwise become float64
                item = np.array(item) if item else np.array([], dtype=int)
            else:
                item = np.arange(0, length)[item]
        if isinstance(item, numbers.Integral):
            item = np.array([item])
        return item

    def _max(self, elems):
        """Return the maximum of ``elems``; a bare integer counts as a one-element collection."""
        if isinstance(elems, numbers.Integral):
            elems = [elems]
        return max(elems)


class NotImplementedRandomAccessStrategy(RandomAccessStrategy):
    """Placeholder strategy whose every access raises NotImplementedError."""

    def _handle_slice(self, idx):
        raise NotImplementedError("Requested random access strategy is not implemented for the current data source.")
15 | # 16 | # You should have received a copy of the GNU Lesser General Public License 17 | # along with this program. If not, see . 18 | 19 | 20 | from __future__ import absolute_import 21 | import numpy as np 22 | import scipy.linalg 23 | import scipy.sparse 24 | import copy 25 | import math 26 | from six.moves import range 27 | 28 | __author__ = 'noe' 29 | 30 | 31 | def mdot(*args): 32 | """Computes a matrix product of multiple ndarrays 33 | 34 | This is a convenience function to avoid constructs such as np.dot(A, np.dot(B, np.dot(C, D))) and instead 35 | use mdot(A, B, C, D). 36 | 37 | Parameters 38 | ---------- 39 | *args : an arbitrarily long list of ndarrays that must be compatible for multiplication, 40 | i.e. args[i].shape[1] = args[i+1].shape[0]. 41 | """ 42 | if len(args) < 1: 43 | raise ValueError('need at least one argument') 44 | elif len(args) == 1: 45 | return args[0] 46 | elif len(args) == 2: 47 | return np.dot(args[0],args[1]) 48 | else: 49 | return np.dot(args[0], mdot(*args[1:])) 50 | 51 | 52 | def submatrix(M, sel): 53 | """Returns a submatrix of the quadratic matrix M, given by the selected columns and row 54 | 55 | Parameters 56 | ---------- 57 | M : ndarray(n,n) 58 | symmetric matrix 59 | sel : int-array 60 | selection of rows and columns. Element i,j will be selected if both are in sel. 
61 | 62 | Returns 63 | ------- 64 | S : ndarray(m,m) 65 | submatrix with m=len(sel) 66 | 67 | """ 68 | assert len(M.shape) == 2, 'M is not a matrix' 69 | assert M.shape[0] == M.shape[1], 'M is not quadratic' 70 | 71 | """Row slicing""" 72 | if scipy.sparse.issparse(M): 73 | C_cc = M.tocsr() 74 | else: 75 | C_cc = M 76 | C_cc = C_cc[sel, :] 77 | 78 | """Column slicing""" 79 | if scipy.sparse.issparse(M): 80 | C_cc = C_cc.tocsc() 81 | C_cc = C_cc[:, sel] 82 | 83 | if scipy.sparse.issparse(M): 84 | return C_cc.tocoo() 85 | else: 86 | return C_cc 87 | 88 | 89 | def _sort_by_norm(evals, evecs): 90 | """ 91 | Sorts the eigenvalues and eigenvectors by descending norm of the eigenvalues 92 | 93 | Parameters 94 | ---------- 95 | evals: ndarray(n) 96 | eigenvalues 97 | evecs: ndarray(n,n) 98 | eigenvectors in a column matrix 99 | 100 | Returns 101 | ------- 102 | (evals, evecs) : ndarray(m), ndarray(n,m) 103 | the sorted eigenvalues and eigenvectors 104 | 105 | """ 106 | # norms 107 | evnorms = np.abs(evals) 108 | # sort 109 | I = np.argsort(evnorms)[::-1] 110 | # permute 111 | evals2 = evals[I] 112 | evecs2 = evecs[:, I] 113 | # done 114 | return (evals2, evecs2) 115 | 116 | 117 | def eig_corr(C0, Ct, epsilon=1e-6): 118 | r""" Solve generalized eigenvalues problem with correlation matrices C0 and Ct 119 | 120 | Numerically robust solution of a generalized eigenvalue problem of the form 121 | 122 | .. math:: 123 | \mathbf{C}_t \mathbf{r}_i = \mathbf{C}_0 \mathbf{r}_i l_i 124 | 125 | Computes :math:`m` dominant eigenvalues :math:`l_i` and eigenvectors :math:`\mathbf{r}_i`, where 126 | :math:`m` is the numerical rank of the problem. 
This is done by first conducting a Schur decomposition 127 | of the symmetric positive matrix :math:`\mathbf{C}_0`, then truncating its spectrum to retain only eigenvalues 128 | that are numerically greater than zero, then using this decomposition to define an ordinary eigenvalue 129 | Problem for :math:`\mathbf{C}_t` of size :math:`m`, and then solving this eigenvalue problem. 130 | 131 | Parameters 132 | ---------- 133 | C0 : ndarray (n,n) 134 | time-instantaneous correlation matrix. Must be symmetric positive definite 135 | Ct : ndarray (n,n) 136 | time-lagged correlation matrix. Must be symmetric 137 | epsilon : float 138 | eigenvalue norm cutoff. Eigenvalues of C0 with norms <= epsilon will be 139 | cut off. The remaining number of Eigenvalues define the size of 140 | the output. 141 | 142 | Returns 143 | ------- 144 | l : ndarray (m) 145 | The first m generalized eigenvalues, sorted by descending norm 146 | R : ndarray (n,m) 147 | The first m generalized eigenvectors, as a column matrix. 148 | 149 | """ 150 | # check input 151 | assert np.allclose(C0.T, C0), 'C0 is not a symmetric matrix' 152 | assert np.allclose(Ct.T, Ct), 'Ct is not a symmetric matrix' 153 | 154 | # compute the Eigenvalues of C0 using Schur factorization 155 | (S, V) = scipy.linalg.schur(C0) 156 | s = np.diag(S) 157 | (s, V) = _sort_by_norm(s, V) # sort them 158 | 159 | # determine the cutoff. We know that C0 is an spd matrix, 160 | # so we select the truncation threshold such that everything that is negative vanishes 161 | evmin = np.min(s) 162 | if evmin < 0: 163 | epsilon = max(epsilon, -evmin + 1e-16) 164 | 165 | # determine effective rank m and perform low-rank approximations. 
166 | evnorms = np.abs(s) 167 | n = np.shape(evnorms)[0] 168 | m = n - np.searchsorted(evnorms[::-1], epsilon) 169 | Vm = V[:, 0:m] 170 | sm = s[0:m] 171 | 172 | # transform Ct to orthogonal basis given by the eigenvectors of C0 173 | Sinvhalf = 1.0 / np.sqrt(sm) 174 | T = np.dot(np.diag(Sinvhalf), Vm.T) 175 | Ct_trans = np.dot(np.dot(T, Ct), T.T) 176 | 177 | # solve the symmetric eigenvalue problem in the new basis 178 | (l, R_trans) = scipy.linalg.eigh(Ct_trans) 179 | (l, R_trans) = _sort_by_norm(l, R_trans) 180 | 181 | # transform the eigenvectors back to the old basis 182 | R = np.dot(T.T, R_trans) 183 | 184 | # return result 185 | return (l, R) 186 | -------------------------------------------------------------------------------- /chainsaw/_ext/variational_estimators/tests/benchmark_moments.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | __author__ = 'noe' 4 | 5 | import time 6 | import numpy as np 7 | from .. 
def genS(N):
    """ Generates sparsities given N (number of cols) """
    S = [10, 90, 100, 500, 900, 1000, 2000, 5000, 7500, 9000, 10000, 20000, 50000, 75000, 90000]  # non-zero
    return [s for s in S if s <= N]


def _gen_data(L, N, n_var, fill):
    """Random (L, N) data; if n_var is given, only the first n_var columns stay random and the rest is `fill`."""
    data = np.random.rand(L, N)  # random data
    if n_var is not None:
        padded = np.full((L, N), fill, dtype=float)
        padded[:, :n_var] = data[:, :n_var]
        data = padded
    return data


def genX(L, N, n_var=None, const=False):
    """Generate an (L, N) array of uniform random numbers.

    If n_var is given, only the first n_var columns are random; the remaining
    columns are filled with 1 if const is True and with 0 otherwise.
    """
    return _gen_data(L, N, n_var, 1.0 if const else 0.0)


def genY(L, N, n_var=None, const=False):
    """Generate an (L, N) array of uniform random numbers.

    If n_var is given, only the first n_var columns are random; the remaining
    columns are filled with -1 if const is True and with 0 otherwise.
    """
    return _gen_data(L, N, n_var, -1.0 if const else 0.0)


def reftime_momentsXX(X, remove_mean=False, nrep=3):
    """Mean wall-clock time (seconds) over nrep runs of the plain numpy computation of the column sum and X'X."""
    # time for reference calculation
    t1 = time.time()
    for r in range(nrep):
        s_ref = X.sum(axis=0)  # computation of mean
        if remove_mean:
            X = X - s_ref/float(X.shape[0])
        C_XX_ref = np.dot(X.T, X)  # covariance matrix
    t2 = time.time()
    # return mean time
    return (t2-t1)/float(nrep)


def mytime_momentsXX(X, remove_mean=False, nrep=3):
    """Mean wall-clock time (seconds) over nrep calls of moments.moments_XX."""
    t1 = time.time()
    for r in range(nrep):
        w, s, C_XX = moments.moments_XX(X, remove_mean=remove_mean)
    t2 = time.time()
    # return mean time
    return (t2-t1)/float(nrep)


def reftime_momentsXXXY(X, Y, remove_mean=False, symmetrize=False, nrep=3):
    """Mean wall-clock time (seconds) over nrep runs of the plain numpy computation of the sums, X'X and X'Y."""
    # time for reference calculation
    t1 = time.time()
    for r in range(nrep):
        sx = X.sum(axis=0)  # computation of mean
        sy = Y.sum(axis=0)  # computation of mean
        if symmetrize:
            sx = 0.5*(sx + sy)
            sy = sx
        if remove_mean:
            X = X - sx/float(X.shape[0])
            Y = Y - sy/float(Y.shape[0])
        if symmetrize:
            C_XX_ref = np.dot(X.T, X) + np.dot(Y.T, Y)
            C_XY = np.dot(X.T, Y)
            C_XY_ref = C_XY + C_XY.T
        else:
            C_XX_ref = np.dot(X.T, X)
            C_XY_ref = np.dot(X.T, Y)
    t2 = time.time()
    # return mean time
    return (t2-t1)/float(nrep)


def mytime_momentsXXXY(X, Y, remove_mean=False, symmetrize=False, nrep=3):
    """Mean wall-clock time (seconds) over nrep calls of moments.moments_XXXY."""
    t1 = time.time()
    for r in range(nrep):
        w, sx, sy, C_XX, C_XY = moments.moments_XXXY(X, Y, remove_mean=remove_mean, symmetrize=symmetrize)
    t2 = time.time()
    # return mean time
    return (t2-t1)/float(nrep)


def benchmark_moments(L=10000, N=10000, nrep=5, xy=False, remove_mean=False, symmetrize=False, const=False):
    """Time the moments routine against the plain numpy reference and print a tab-separated report.

    The reference is timed once on dense random data. The moments routine is
    timed for every sparsity returned by genS(N), i.e. with only the first s
    columns random.

    Parameters
    ----------
    L, N : int
        number of data points (rows) and dimensions (columns)
    nrep : int
        number of repetitions per timing
    xy : bool
        benchmark moments_XXXY (True) or moments_XX (False)
    remove_mean, symmetrize : bool
        forwarded to the timed routines (symmetrize only if xy is True)
    const : bool
        forwarded to genX/genY: fill the non-random columns with +1/-1 instead of 0
    """
    S = genS(N)
    fname = 'moments_XXXY' if xy else 'moments_XX'

    # time for reference calculation
    X = genX(L, N)
    if xy:
        Y = genY(L, N)
        reftime = reftime_momentsXXXY(X, Y, remove_mean=remove_mean, symmetrize=symmetrize, nrep=nrep)
    else:
        reftime = reftime_momentsXX(X, remove_mean=remove_mean, nrep=nrep)

    # my time
    times = np.zeros(len(S))
    for k, s in enumerate(S):
        X = genX(L, N, n_var=s, const=const)
        if xy:
            Y = genY(L, N, n_var=s, const=const)
            times[k] = mytime_momentsXXXY(X, Y, remove_mean=remove_mean, symmetrize=symmetrize, nrep=nrep)
        else:
            times[k] = mytime_momentsXX(X, remove_mean=remove_mean, nrep=nrep)

    # assemble report; the timing row names the routine that was actually
    # measured (it used to read 'time moments_XX' for XXXY runs as well)
    rows = ['L, data points', 'N, dimensions', 'S, nonzeros', 'time trivial', 'time ' + fname, 'speed-up']
    formats = ['\t%i', '\t%i', '\t%i', '\t%.3f', '\t%.3f', '\t%.3f']
    table = np.zeros((6, len(S)))
    table[0, :] = L
    table[1, :] = N
    table[2, :] = S
    table[3, :] = reftime
    table[4, :] = times
    table[5, :] = reftime / times

    # print table
    print(fname + '\tremove_mean = ' + str(remove_mean) + '\tsym = ' + str(symmetrize) + '\tconst = ' + str(const))
    for label, fmt, values in zip(rows, formats, table):
        print(label + ((fmt * table.shape[1]) % tuple(values)))
    print()


def main():
    """Run the full benchmark grid for every (L, N, nrep) problem size."""
    LNs = [(100000, 100, 10), (10000, 1000, 7), (1000, 2000, 5), (250, 5000, 5), (100, 10000, 5)]
    # symmetrize is only meaningful for the XY variant
    variants = ((False, (False,)), (True, (False, True)))
    for L, N, nrep in LNs:
        for xy, sym_options in variants:
            for remove_mean in (False, True):
                for symmetrize in sym_options:
                    for const in (False, True):
                        benchmark_moments(L=L, N=N, nrep=nrep, xy=xy, remove_mean=remove_mean,
                                          symmetrize=symmetrize, const=const)
@fix_docs
class RegularSpaceClustering(AbstractClustering):
    r"""Regular space clustering"""

    def __init__(self, dmin, max_centers=1000, metric='euclidean', stride=1, n_jobs=None, skip=0):
        """Clusters data objects in such a way, that cluster centers are at least in
        distance of dmin to each other according to the given metric.
        The assignment of data objects to cluster centers is performed by
        Voronoi partitioning.

        Regular space clustering [Prinz_2011]_ is very similar to Hartigan's leader
        algorithm [Hartigan_1975]_. It consists of two passes through
        the data. Initially, the first data point is added to the list of centers.
        For every subsequent data point, if it has a greater distance than dmin from
        every center, it also becomes a center. In the second pass, a Voronoi
        discretization with the computed centers is used to partition the data.


        Parameters
        ----------
        dmin : float
            minimum distance between all clusters.
        metric : str
            metric to use during clustering ('euclidean', 'minRMSD')
        max_centers : int
            if this cutoff is hit during finding the centers,
            the algorithm will abort.
        stride : int
            stride handed to the data iterator while searching for centers.
        n_jobs : int or None, default None
            Number of threads to use during assignment of the data.
            If None, all available CPUs will be used.
        skip : int
            skip value handed to the data iterator while searching for centers.

        References
        ----------

        .. [Prinz_2011] Prinz J-H, Wu H, Sarich M, Keller B, Senne M, Held M, Chodera JD, Schuette Ch and Noe F. 2011.
            Markov models of molecular kinetics: Generation and Validation.
            J. Chem. Phys. 134, 174105.
        .. [Hartigan_1975] Hartigan J. Clustering algorithms.
            New York: Wiley; 1975.

        """
        super(RegularSpaceClustering, self).__init__(metric=metric, n_jobs=n_jobs)

        self.set_params(dmin=dmin, metric=metric,
                        max_centers=max_centers, stride=stride, skip=skip)

    def describe(self):
        return "[RegularSpaceClustering dmin=%f, inp_dim=%i]" % (self._dmin, self.data_producer.dimension())

    @property
    def dmin(self):
        """Minimum distance between cluster centers."""
        return self._dmin

    @dmin.setter
    def dmin(self, d):
        if d < 0:
            raise ValueError("d has to be positive")

        self._dmin = float(d)
        # a changed parameter invalidates a previous estimation
        self._estimated = False

    @property
    def max_centers(self):
        """
        Cutoff during clustering. If reached no more data is taken into account.
        You might then consider a larger value or a larger dmin value.
        """
        return self._max_centers

    @max_centers.setter
    def max_centers(self, value):
        if value < 0:
            raise ValueError("max_centers has to be positive")

        self._max_centers = int(value)
        # a changed parameter invalidates a previous estimation
        self._estimated = False

    @property
    def n_clusters(self):
        """Alias of max_centers: reading and writing both go to that property."""
        return self.max_centers

    @n_clusters.setter
    def n_clusters(self, val):
        self.max_centers = val

    def _estimate(self, iterable, **kwargs):
        """Collect the cluster centers in one pass over ``iterable``.

        Returns
        -------
        self

        Raises
        ------
        NotConvergedWarning
            if the center search raised RuntimeError (reported as the maximum
            number of centers being reached). The centers found so far are
            stored in the model before raising.
        """
        ########
        # Calculate clustercenters:
        # 1. choose first datapoint as centroid
        # 2. for all X: calc distances to all clustercenters
        # 3. add new centroid, if min(distance to all other clustercenters) >= dmin
        ########
        # temporary list to store cluster centers
        clustercenters = []
        used_frames = 0
        it = iterable.iterator(return_trajindex=False, stride=self.stride,
                               chunk=self.chunksize, skip=self.skip)
        try:
            with it:
                for X in it:
                    used_frames += len(X)
                    _regspatial.cluster(X.astype(np.float32, order='C', copy=False),
                                        clustercenters, self.dmin,
                                        self.metric, self.max_centers)
        except RuntimeError:
            msg = 'Maximum number of cluster centers reached.' \
                  ' Consider increasing max_centers or choose' \
                  ' a larger minimum distance, dmin.'
            self._logger.warning(msg)
            warnings.warn(msg)
            # finished anyway, because we have no more space for clusters. Rest of trajectory has no effect
            clustercenters = np.array(clustercenters)
            # keyword was misspelled 'n_cluster' here, unlike the regular path below
            self.update_model_params(clustercenters=clustercenters,
                                     n_clusters=len(clustercenters))
            # pass amount of processed data
            used_data = used_frames / float(it.n_frames_total()) * 100.0
            raise NotConvergedWarning("Used data for centers: %.2f%%" % used_data)

        clustercenters = np.array(clustercenters)
        self.update_model_params(clustercenters=clustercenters,
                                 n_clusters=len(clustercenters))

        if len(clustercenters) == 1:
            self._logger.warning('Have found only one center according to '
                                 'minimum distance requirement of %f' % self.dmin)

        return self
class TestPCAExtensive(unittest.TestCase):
    """Exercises the ``pca`` API on a synthetic two-state, 2D trajectory.

    ``setUpClass`` draws a hidden two-state trajectory from the transition
    matrix ``P`` and emits one Gaussian sample per frame; a PCA with
    ``dim=1`` on that data is shared by most tests as ``cls.pca_obj``.
    """

    @classmethod
    def setUpClass(cls):
        import msmtools.generation as msmgen

        # set random state, remember old one and set it back in tearDownClass
        cls.old_state = np.random.get_state()
        np.random.seed(0)

        # generate HMM with two Gaussians
        cls.P = np.array([[0.99, 0.01],
                          [0.01, 0.99]])
        cls.T = 10000
        means = [np.array([-1, 1]), np.array([1, -1])]
        widths = [np.array([0.3, 2]), np.array([0.3, 2])]
        # continuous trajectory
        cls.X = np.zeros((cls.T, 2))
        # hidden trajectory
        dtraj = msmgen.generate_traj(cls.P, cls.T)
        for t in range(cls.T):
            s = dtraj[t]
            cls.X[t, 0] = widths[s][0] * np.random.randn() + means[s][0]
            cls.X[t, 1] = widths[s][1] * np.random.randn() + means[s][1]
        cls.pca_obj = pca(data=cls.X, dim=1)

    @classmethod
    def tearDownClass(cls):
        np.random.set_state(cls.old_state)

    def test_chunksize(self):
        assert types.is_int(self.pca_obj.chunksize)

    def test_variances(self):
        obj = pca(data=self.X)
        O = obj.get_output()[0]
        # variance of each projected coordinate should match its eigenvalue
        variances = np.var(O, axis=0)
        refs = obj.eigenvalues
        assert np.max(np.abs(variances - refs)) < 0.01

    def test_cumvar(self):
        assert len(self.pca_obj.cumvar) == 2
        assert np.allclose(self.pca_obj.cumvar[-1], 1.0)

    def test_cov(self):
        cov_ref = np.dot(self.X.T, self.X) / float(self.T)
        assert self.pca_obj.cov.shape == cov_ref.shape
        # NOTE(review): cov_ref is the uncentered second moment and the
        # difference is not taken in absolute value, so this only bounds the
        # deviation from above - confirm whether a two-sided check against
        # the mean-free covariance is intended.
        assert np.max(self.pca_obj.cov - cov_ref) < 3e-2

    def test_data_producer(self):
        assert self.pca_obj.data_producer is not None

    def test_describe(self):
        desc = self.pca_obj.describe()
        assert types.is_string(desc) or types.is_list_of_string(desc)

    def test_dimension(self):
        assert types.is_int(self.pca_obj.dimension())
        # Here:
        assert self.pca_obj.dimension() == 1
        # Test other variants
        obj = pca(data=self.X, dim=-1, var_cutoff=1.0)
        assert obj.dimension() == 2
        obj = pca(data=self.X, dim=-1, var_cutoff=0.8)
        assert obj.dimension() == 1
        # trying to set both dim and subspace_variance is forbidden
        with self.assertRaises(ValueError):
            pca(data=self.X, dim=1, var_cutoff=0.8)

    def test_eigenvalues(self):
        evals = self.pca_obj.eigenvalues
        assert len(evals) == 2

    def test_eigenvectors(self):
        evec = self.pca_obj.eigenvectors
        assert evec.shape == (2, 2)

    def test_get_output(self):
        O = self.pca_obj.get_output()
        assert types.is_list(O)
        assert len(O) == 1
        assert types.is_float_matrix(O[0])
        assert O[0].shape[0] == self.T
        assert O[0].shape[1] == self.pca_obj.dimension()

    def test_in_memory(self):
        assert isinstance(self.pca_obj.in_memory, bool)

    def test_iterator(self):
        for itraj, chunk in self.pca_obj:
            assert types.is_int(itraj)
            assert types.is_float_matrix(chunk)
            assert chunk.shape[1] == self.pca_obj.dimension()

    def test_map(self):
        Y = self.pca_obj.transform(self.X)
        assert Y.shape[0] == self.T
        assert Y.shape[1] == 1
        # test if consistent with get_output
        assert np.allclose(Y, self.pca_obj.get_output()[0])

    def test_mean(self):
        mean = self.pca_obj.mean
        assert len(mean) == 2
        # every component must be close to zero; the former
        # ``np.max(mean < 0.5)`` passed as soon as one component was < 0.5
        assert np.max(np.abs(mean)) < 0.5

    def test_n_frames_total(self):
        # map not defined for source
        assert self.pca_obj.n_frames_total() == self.T

    def test_number_of_trajectories(self):
        # map not defined for source
        assert self.pca_obj.number_of_trajectories() == 1

    def test_output_type(self):
        assert self.pca_obj.output_type() == np.float32

    def test_trajectory_length(self):
        assert self.pca_obj.trajectory_length(0) == self.T
        with self.assertRaises(IndexError):
            self.pca_obj.trajectory_length(1)

    def test_trajectory_lengths(self):
        assert len(self.pca_obj.trajectory_lengths()) == 1
        assert self.pca_obj.trajectory_lengths()[0] == self.pca_obj.trajectory_length(0)

    def test_provided_means(self):
        data = np.random.random((300, 3))
        mean = data.mean(axis=0)
        pca_spec_mean = pca(data, mean=mean)
        pca_calc_mean = pca(data)

        np.testing.assert_allclose(mean, pca_calc_mean.mean)
        np.testing.assert_allclose(mean, pca_spec_mean.mean)

        np.testing.assert_allclose(pca_spec_mean.cov, pca_calc_mean.cov)

    def test_partial_fit(self):
        data = [np.random.random((100, 3)), np.random.random((100, 3))]
        pca_part = pca()
        pca_part.partial_fit(data[0])
        pca_part.partial_fit(data[1])

        # two partial fits must agree with one fit over both arrays
        ref = pca(data)
        np.testing.assert_allclose(pca_part.mean, ref.mean)

        np.testing.assert_allclose(pca_part.eigenvalues, ref.eigenvalues)
        np.testing.assert_allclose(pca_part.eigenvectors, ref.eigenvectors)
@fix_docs
@FileFormatRegistry.register('.npy')
class NumPyFileReader(DataSource):

    """reads NumPy files in chunks. Supports .npy files

    Parameters
    ----------
    filenames : str or list of strings

    chunksize : int
        how many rows are read at once

    mmap_mode : str (optional), default='r'
        binary NumPy arrays are being memory mapped using this flag.
    """

    def __init__(self, filenames, chunksize=1000, mmap_mode='r', **kw):
        # additional keyword arguments are accepted but not used here
        super(NumPyFileReader, self).__init__(chunksize=chunksize)
        self._is_reader = True

        # a single file name is wrapped into a list
        if not isinstance(filenames, (list, tuple)):
            filenames = [filenames]

        self.mmap_mode = mmap_mode
        # NOTE(review): _load_file reads ``self._filenames``; presumably the
        # base class exposes ``filenames`` as a property backed by it.
        self.filenames = filenames

    def _create_iterator(self, skip=0, chunk=0, stride=1, return_trajindex=False, cols=None):
        return NPYIterator(self, skip=skip, chunk=chunk, stride=stride,
                           return_trajindex=return_trajindex, cols=cols)

    def describe(self):
        shapes = [(l, self.ndim) for l
                  in self._lengths]
        return "[NumpyFileReader arrays with shapes: %s]" % shapes

    def _reshape(self, array):
        """Return ``array`` as a two-dimensional array.

        1d input becomes a single column, 2d input is returned unchanged
        and higher-dimensional input keeps its first axis while all
        remaining axes are flattened into the second one.
        """
        if array.ndim == 1:
            array = np.atleast_2d(array).T
        elif array.ndim == 2:
            pass
        else:
            shape = array.shape
            # hold first dimension, multiply the rest
            shape_2d = (shape[0],
                        functools.reduce(lambda x, y: x * y, shape[1:]))
            array = np.reshape(array, shape_2d)
        return array

    def _load_file(self, itraj):
        """Load file number ``itraj`` with ``np.load`` and return it as 2d array."""
        filename = self._filenames[itraj]

        x = np.load(filename, mmap_mode=self.mmap_mode)
        arr = self._reshape(x)
        return arr

    def _get_traj_info(self, filename):
        """Return a ``TrajInfo(ndim, length)`` for the given file name."""
        idx = self.filenames.index(filename)
        array = self._load_file(idx)
        length, ndim = np.shape(array)

        return TrajInfo(ndim, length)


class NPYIterator(DataSourceIterator):
    """Iterates over the arrays of a ``NumPyFileReader``, whole or in chunks.

    The array of the current trajectory is kept in ``self._array`` and
    ``self._last_itraj`` records which trajectory it belongs to, so a file
    is only loaded again when the trajectory index has changed.
    """

    def __init__(self, data_source, skip=0, chunk=0, stride=1, return_trajindex=False, cols=False):
        # NOTE(review): ``cols`` defaults to False here, whereas
        # NumPyFileReader._create_iterator defaults it to None - confirm
        # which value the base class expects for "no column selection".
        super(NPYIterator, self).__init__(data_source=data_source, skip=skip,
                                          chunk=chunk, stride=stride,
                                          return_trajindex=return_trajindex,
                                          cols=cols)

        self._array = None
        self._last_itraj = -1

    def reset(self):
        DataSourceIterator.reset(self)
        self._last_itraj = -1

    def close(self):
        self._close_filehandle()

    def _close_filehandle(self):
        """Drop the reference to the current array, if any."""
        if getattr(self, '_array', None) is None:
            return
        self._array = None
        # nothing is loaded any more, force a reload on the next access
        self._last_itraj = -1

    def _open_filehandle(self):
        """Load the array of the current trajectory and remember its index."""
        self._array = self._data_source._load_file(self._itraj)
        # previously never updated, which made every chunk reload the file
        self._last_itraj = self._itraj

    def _next_chunk(self):
        """Return the next chunk (or whole trajectory if chunksize is 0).

        Raises
        ------
        StopIteration
            when the trajectory index has run past the last trajectory.
        """
        if self._itraj >= self._data_source.ntraj:
            self.close()
            raise StopIteration()

        # (re)load only if the trajectory changed since the last load
        if self._itraj != self._last_itraj:
            self._close_filehandle()
            self._open_filehandle()

        traj_len = len(self._array)
        traj = self._array

        # skip only if complete trajectory mode or first chunk
        skip = self.skip if self.chunksize == 0 or self._t == 0 else 0

        # if stride by dict, update traj length accordingly
        if not self.uniform_stride:
            traj_len = self.ra_trajectory_length(self._itraj)

        # complete trajectory mode
        if self.chunksize == 0:
            if not self.uniform_stride:
                X = traj[self.ra_indices_for_traj(self._itraj)]
                self._itraj += 1

                # skip the trajs that are not in the stride dict
                while self._itraj < self.number_of_trajectories() \
                        and (self._itraj not in self.traj_keys):
                    self._itraj += 1

            else:
                X = traj[skip::self.stride]
                self._itraj += 1

            return X

        # chunked mode
        if not self.uniform_stride:
            upper_bound = min(self._t + self.chunksize, traj_len)
            X = traj[self.ra_indices_for_traj(self._itraj)[self._t:upper_bound]]
        else:
            upper_bound = min(skip + self._t + self.chunksize * self.stride, traj_len)
            slice_x = slice(skip + self._t, upper_bound, self.stride)
            X = traj[slice_x]

        # set new time position
        self._t = upper_bound

        if self._t >= traj_len:
            self._itraj += 1
            self._t = 0

            # if we have a dictionary, skip trajectories that are not in the key set
            while not self.uniform_stride and self._itraj < self.number_of_trajectories() \
                    and (self._itraj not in self.traj_keys):
                self._itraj += 1

            # if time index scope ran out of len of current trajectory, open next file.
            if self._itraj <= self.number_of_trajectories() - 1:
                self._close_filehandle()
                self._open_filehandle()

        return X