├── chainsaw
├── _ext
│ ├── __init__.py
│ └── variational_estimators
│ │ ├── tests
│ │ ├── __init__.py
│ │ └── benchmark_moments.py
│ │ ├── covar_c
│ │ ├── __init__.py
│ │ └── _covartools.h
│ │ ├── __init__.py
│ │ └── lagged_correlation.py
├── base
│ ├── __init__.py
│ ├── reporter.py
│ ├── estimator.py
│ ├── loggable.py
│ └── model.py
├── data
│ ├── _base
│ │ ├── __init__.py
│ │ └── random_accessible.py
│ ├── md
│ │ ├── featurization
│ │ │ ├── __init__.py
│ │ │ └── _base.py
│ │ └── __init__.py
│ ├── util
│ │ ├── __init__.py
│ │ └── fileformat_registry.py
│ ├── __init__.py
│ └── numpy_filereader.py
├── tests
│ ├── data
│ │ ├── bpti_mini.dcd
│ │ ├── bpti_mini.h5
│ │ ├── bpti_mini.lh5
│ │ ├── bpti_mini.nc
│ │ ├── bpti_mini.trr
│ │ ├── bpti_mini.xtc
│ │ ├── bpti_001-033.xtc
│ │ ├── bpti_034-066.xtc
│ │ ├── bpti_067-100.xtc
│ │ ├── bpti_mini.binpos
│ │ ├── bpti_mini.netcdf
│ │ ├── opsin_Ca_1_frame.pdb.gz
│ │ ├── opsin_aa_1_frame.pdb.gz
│ │ ├── test.pdb
│ │ └── bpti_ca.pdb
│ ├── __init__.py
│ ├── test_format_registry.py
│ ├── test_acf.py
│ ├── test_uniform_time.py
│ ├── test_mini_batch_kmeans.py
│ ├── util.py
│ ├── test_cluster_samples.py
│ ├── test_api_load.py
│ ├── test_cache.py
│ ├── test_featurereader_and_tica.py
│ ├── test_stride.py
│ ├── test_regspace.py
│ ├── test_discretizer.py
│ ├── test_coordinates_iterator.py
│ ├── test_featurereader_and_tica_projection.py
│ ├── test_cluster.py
│ └── test_pca.py
├── util
│ ├── __init__.py
│ ├── contexts.py
│ ├── exceptions.py
│ ├── change_notification.py
│ ├── files.py
│ ├── indices.py
│ ├── stat.py
│ ├── log.py
│ ├── units.py
│ ├── reflection.py
│ └── linalg.py
├── _resources
│ ├── logging.yml
│ └── chainsaw.cfg
├── transform
│ └── __init__.py
├── clustering
│ ├── __init__.py
│ ├── include
│ │ └── clustering.h
│ ├── assign.py
│ ├── uniform_time.py
│ └── regspace.py
├── __init__.py
└── acf.py
├── .gitattributes
├── README.rst
├── devtools
├── conda-recipe
│ ├── build.sh
│ ├── build.sh.orig
│ ├── bld.bat
│ ├── meta.yaml
│ └── run_test.py
└── ci
│ ├── appveyor
│ ├── after_success.bat
│ ├── deploy.ps1
│ ├── runTestsuite.ps1
│ ├── run_with_env.cmd
│ └── transform_xunit_to_appveyor.xsl
│ ├── travis
│ ├── install_miniconda.sh
│ ├── make_docs.sh
│ ├── after_success.sh
│ └── dev_pkgs_del_old.py
│ └── jenkins
│ └── update_versions_json.py
├── .gitignore
├── MANIFEST.in
├── setup.cfg
├── .travis.yml
└── setup_util.py
/chainsaw/_ext/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chainsaw/base/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chainsaw/data/_base/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | chainsaw/_version.py export-subst
2 |
--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
1 | chainsaw
2 | ========
3 |
4 | desc
5 |
--------------------------------------------------------------------------------
/chainsaw/_ext/variational_estimators/tests/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'noe'
2 |
--------------------------------------------------------------------------------
/devtools/conda-recipe/build.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | $PYTHON setup.py install
3 |
--------------------------------------------------------------------------------
/chainsaw/_ext/variational_estimators/covar_c/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'noe'
2 |
--------------------------------------------------------------------------------
/chainsaw/data/md/featurization/__init__.py:
--------------------------------------------------------------------------------
1 | from .featurizer import MDFeaturizer, CustomFeature
2 |
--------------------------------------------------------------------------------
/chainsaw/tests/data/bpti_mini.dcd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/markovmodel/coordinates/master/chainsaw/tests/data/bpti_mini.dcd
--------------------------------------------------------------------------------
/chainsaw/tests/data/bpti_mini.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/markovmodel/coordinates/master/chainsaw/tests/data/bpti_mini.h5
--------------------------------------------------------------------------------
/chainsaw/tests/data/bpti_mini.lh5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/markovmodel/coordinates/master/chainsaw/tests/data/bpti_mini.lh5
--------------------------------------------------------------------------------
/chainsaw/tests/data/bpti_mini.nc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/markovmodel/coordinates/master/chainsaw/tests/data/bpti_mini.nc
--------------------------------------------------------------------------------
/chainsaw/tests/data/bpti_mini.trr:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/markovmodel/coordinates/master/chainsaw/tests/data/bpti_mini.trr
--------------------------------------------------------------------------------
/chainsaw/tests/data/bpti_mini.xtc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/markovmodel/coordinates/master/chainsaw/tests/data/bpti_mini.xtc
--------------------------------------------------------------------------------
/chainsaw/tests/data/bpti_001-033.xtc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/markovmodel/coordinates/master/chainsaw/tests/data/bpti_001-033.xtc
--------------------------------------------------------------------------------
/chainsaw/tests/data/bpti_034-066.xtc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/markovmodel/coordinates/master/chainsaw/tests/data/bpti_034-066.xtc
--------------------------------------------------------------------------------
/chainsaw/tests/data/bpti_067-100.xtc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/markovmodel/coordinates/master/chainsaw/tests/data/bpti_067-100.xtc
--------------------------------------------------------------------------------
/chainsaw/tests/data/bpti_mini.binpos:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/markovmodel/coordinates/master/chainsaw/tests/data/bpti_mini.binpos
--------------------------------------------------------------------------------
/chainsaw/tests/data/bpti_mini.netcdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/markovmodel/coordinates/master/chainsaw/tests/data/bpti_mini.netcdf
--------------------------------------------------------------------------------
/chainsaw/tests/data/opsin_Ca_1_frame.pdb.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/markovmodel/coordinates/master/chainsaw/tests/data/opsin_Ca_1_frame.pdb.gz
--------------------------------------------------------------------------------
/chainsaw/tests/data/opsin_aa_1_frame.pdb.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/markovmodel/coordinates/master/chainsaw/tests/data/opsin_aa_1_frame.pdb.gz
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | build
3 |
4 | *.pyc
5 | *.so
6 | *.egg-info/
7 | *.orig
8 |
9 | */_ext/variational_estimators/covar_c/covartools.c
10 |
11 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include versioneer.py
2 | include chainsaw/_version.py
3 |
4 | recursive-include chainsaw/_resources *
5 | recursive-include chainsaw/tests/data *
6 |
--------------------------------------------------------------------------------
/chainsaw/tests/data/test.pdb:
--------------------------------------------------------------------------------
1 | ATOM 3 CH3 ACE A 1 28.490 31.600 33.379 0.00 1.00
2 | ATOM 7 C ACE A 2 27.760 30.640 34.299 0.00 1.00
3 | ATOM 8 C ACE A 2 27.760 30.640 34.299 0.00 1.00
4 |
5 |
--------------------------------------------------------------------------------
/chainsaw/data/md/__init__.py:
--------------------------------------------------------------------------------
1 | import pkgutil
2 |
3 | loader = pkgutil.find_loader('mdtraj')
4 |
5 | if loader is not None:
6 | from .feature_reader import FeatureReader
7 | from .featurization.featurizer import MDFeaturizer, CustomFeature
8 |
--------------------------------------------------------------------------------
/chainsaw/_ext/variational_estimators/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .moments import moments_XX, moments_XXXY, moments_block
4 | from .moments import covar, covars
5 | from .running_moments import RunningCovar, running_covar
6 |
--------------------------------------------------------------------------------
/devtools/conda-recipe/build.sh.orig:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | $PYTHON setup.py install
3 | <<<<<<< HEAD
4 | version=$($PYTHON -c "from __future__ import print_function; import pyemma; print(pyemma.__version__)")
5 | export PYEMMA_VERSION=$version
6 | =======
7 | >>>>>>> devel
8 |
--------------------------------------------------------------------------------
/devtools/ci/appveyor/after_success.bat:
--------------------------------------------------------------------------------
1 | conda install --yes -q anaconda-client jinja2
2 | cd %PYTHON_MINICONDA%\conda-bld
3 | dir /s /b %PACKAGENAME%-dev-*.tar.bz2 > files.txt
4 | for /F %%f in (files.txt) do (
5 | echo "uploading file %%~f"
6 | anaconda -t %BINSTAR_TOKEN% upload --force -u %ORGNAME% -p %PACKAGENAME%-dev %%~f
7 | )
8 |
--------------------------------------------------------------------------------
/devtools/conda-recipe/bld.bat:
--------------------------------------------------------------------------------
1 | if not defined APPVEYOR (
2 | echo not on appveyor
3 | "%PYTHON%" setup.py install
4 | ) else (
5 | echo on appveyor
6 | cmd /E:ON /V:ON /C %APPVEYOR_BUILD_FOLDER%\devtools\ci\appveyor\run_with_env.cmd "%PYTHON%" setup.py install
7 | )
8 | set build_status=%ERRORLEVEL%
9 | if %build_status% == 1 exit 1
10 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 |
2 | # See the docstring in versioneer.py for instructions. Note that you must
3 | # re-run 'versioneer.py setup' after changing this section, and commit the
4 | # resulting files.
5 |
6 | [versioneer]
7 | VCS = git
8 | style = pep440
9 | versionfile_source = chainsaw/_version.py
10 | versionfile_build = chainsaw/_version.py
11 | tag_prefix = v
12 | #parentdir_prefix =
13 |
14 |
--------------------------------------------------------------------------------
/devtools/ci/appveyor/deploy.ps1:
--------------------------------------------------------------------------------
1 | function deploy() {
2 | # install tools
3 | pip install wheel twine
4 |
5 | # create wheel and win installer
6 | python setup.py bdist_wheel bdist_wininst
7 |
8 | # upload to pypi with twine
9 | twine upload -i $env:myuser -p $env:mypass dist/*
10 | }
11 |
12 | $new_tag = ($env:APPVEYOR_REPO_TAG -eq "true")
13 | $new_tag = $true # temporarily enable for all commits
14 |
15 | if ($new_tag) {
16 | deploy
17 | }
--------------------------------------------------------------------------------
/chainsaw/base/reporter.py:
--------------------------------------------------------------------------------
1 | from progress_reporter import ProgressReporter as _impl
2 |
3 |
4 | class ProgressReporter(_impl):
5 | @_impl.show_progress.getter
6 | def show_progress(self):
7 | """ whether to show the progress of heavy calculations on this object. """
8 | if not hasattr(self, "_show_progress"):
9 | from chainsaw import config
10 | val = config.show_progress_bars
11 | self._show_progress = val
12 | return self._show_progress
13 |
--------------------------------------------------------------------------------
/devtools/ci/travis/install_miniconda.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # make TARGET overrideable with env
4 | : ${TARGET:=$HOME/miniconda}
5 |
6 | function install_miniconda {
7 | if [ -d $TARGET ]; then echo "file exists"; return; fi
8 | echo "installing miniconda to $TARGET"
9 | if [[ "$TRAVIS_OS_NAME" == "linux" ]]; then
10 | platform="Linux"
11 | elif [[ "$TRAVIS_OS_NAME" == "osx" ]]; then
12 | platform="MacOSX"
13 | fi
14 | wget http://repo.continuum.io/miniconda/Miniconda3-latest-$platform-x86_64.sh -O mc.sh -o /dev/null
15 | bash mc.sh -b -f -p $TARGET
16 | }
17 |
18 | install_miniconda
19 | export PATH=$TARGET/bin:$PATH
20 |
--------------------------------------------------------------------------------
/chainsaw/util/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
18 |
19 |
--------------------------------------------------------------------------------
/chainsaw/data/util/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
18 |
19 |
--------------------------------------------------------------------------------
/chainsaw/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
18 |
19 | from __future__ import absolute_import
20 |
--------------------------------------------------------------------------------
/chainsaw/_ext/variational_estimators/covar_c/_covartools.h:
--------------------------------------------------------------------------------
1 | void _subtract_row_double(double* X, double* row, int M, int N);
2 | void _subtract_row_float(float* X, float* row, int M, int N);
3 | void _subtract_row_double_copy(double* X0, double* X, double* row, int M, int N);
4 | int* _bool_to_list(int* b, int N, int nnz);
5 | void _variable_cols_char(int* cols, char* X, int M, int N, int min_constant);
6 | void _variable_cols_int(int* cols, int* X, int M, int N, int min_constant);
7 | void _variable_cols_long(int* cols, long* X, int M, int N, int min_constant);
8 | void _variable_cols_float(int* cols, float* X, int M, int N, int min_constant);
9 | void _variable_cols_double(int* cols, double* X, int M, int N, int min_constant);
10 | void _variable_cols_float_approx(int* cols, float* X, int M, int N, float tol, int min_constant);
11 | void _variable_cols_double_approx(int* cols, double* X, int M, int N, double tol, int min_constant);
12 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: c
2 | git:
3 | submodules: false
4 | os:
5 | - osx
6 | - linux
7 |
8 | sudo: false
9 |
10 | env:
11 | global:
12 | - PATH=$HOME/miniconda/bin:$PATH
13 | - common_py_deps="jinja2 conda-build=1.21.3"
14 | - PACKAGENAME=chainsaw
15 | - ORGNAME=omnia
16 | - DEV_BUILD_N_KEEP=10
17 | matrix:
18 | - CONDA_PY=2.7 CONDA_NPY=1.11
19 | - CONDA_PY=3.4 CONDA_NPY=1.10
20 | - CONDA_PY=3.5 CONDA_NPY=1.11
21 |
22 | before_install:
23 | - devtools/ci/travis/install_miniconda.sh
24 | - conda config --set always_yes true
25 | - conda config --add channels omnia
26 | - conda install -q $common_py_deps
27 |
28 | script:
29 | - conda build -q devtools/conda-recipe
30 |
31 | after_success:
32 | # coverage report: needs .coverage file generated by testsuite and git src
33 | - pip install coveralls
34 | - coveralls
35 | #- if [ "$TRAVIS_SECURE_ENV_VARS" == true ]; then source devtools/ci/travis/after_success.sh; fi
36 |
37 |
--------------------------------------------------------------------------------
/devtools/ci/travis/make_docs.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | function install_deps {
3 | wget https://github.com/jgm/pandoc/releases/download/1.13.2/pandoc-1.13.2-1-amd64.deb \
4 | -O pandoc.deb
5 | dpkg -x pandoc.deb $HOME
6 |
7 | export PATH=$PATH:$HOME/usr/bin
8 | # try to execute pandoc
9 | pandoc --version
10 |
11 | conda install -q --yes $doc_deps
12 | pip install -r requirements-build-doc.txt wheel
13 | }
14 |
15 | function build_doc {
16 | pushd doc; make ipython-rst html
17 | # workaround for docs dir => move doc to build/docs afterwards
18 | # travis (currently )expects docs in build/docs (should contain index.html?)
19 | mv build/html ../build/docs
20 | popd
21 | }
22 |
23 | function deploy_doc {
24 | echo "[distutils]
25 | index-servers = pypi
26 |
27 | [pypi]
28 | username:marscher
29 | password:${pypi_pass}" > ~/.pypirc
30 |
31 | python setup.py upload_docs
32 | }
33 |
34 | # build docs only for python 2.7 and for normal commits (not pull requests)
35 | if [[ $TRAVIS_PYTHON_VERSION = "2.7" ]] && [[ "${TRAVIS_PULL_REQUEST}" = "false" ]]; then
36 | install_deps
37 | build_doc
38 | deploy_doc
39 | fi
40 |
--------------------------------------------------------------------------------
/chainsaw/_resources/logging.yml:
--------------------------------------------------------------------------------
1 | # Chainsaw's default logging settings
2 | # If you want to enable file logging, uncomment the file related handlers and handlers
3 | #
4 |
5 | # do not disable other loggers by default.
6 | disable_existing_loggers: False
7 |
8 | # please do not change version, it is an internal variable used by Python.
9 | version: 1
10 |
11 | formatters:
12 | simpleFormater:
13 | format: '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
14 | datefmt: '%d-%m-%y %H:%M:%S'
15 |
16 | handlers:
17 | # log to stdout
18 | console:
19 | class: logging.StreamHandler
20 | formatter: simpleFormater
21 | stream: ext://sys.stdout
22 | # example for rotating log files, disabled by default
23 | #rotating_files:
24 | # class: logging.handlers.RotatingFileHandler
25 | # formatter: simpleFormater
26 | # filename: chainsaw.log
27 | # maxBytes: 1048576 # 1 MB
28 | # backupCount: 3
29 |
30 | loggers:
31 | chainsaw:
32 | level: INFO
33 | # by default no file logging!
34 | handlers: [console] #, rotating_files]
35 |
36 |
--------------------------------------------------------------------------------
/chainsaw/_resources/chainsaw.cfg:
--------------------------------------------------------------------------------
1 | ################################################################################
2 | # Chainsaw configuration file
3 | #
4 | # notes:
5 | # - comments are not allowed in line, since they would be appended to the value!
6 | ################################################################################
7 |
8 | [chainsaw]
9 | # configuration notice shown?
10 | show_config_notification = False
11 |
12 | # Source to logging configuration file (YAML).
13 | # Special value: DEFAULT (use default config).
14 | # If this is set to a filename, it will be read to configure logging. If it is a
15 | # relative path, it is assumed to be located next to where you start your interpreter.
16 | logging_config = DEFAULT
17 |
18 | # show or hide progress bars globally?
19 | show_progress_bars = True
20 |
21 | # useful for trajectory formats, for which one has to read the whole file to get len
22 | # eg. XTC format.
23 | use_trajectory_lengths_cache = True
24 | # maximum entries in database
25 | traj_info_max_entries = 50000
26 | # max size in MB
27 | traj_info_max_size = 500
28 |
29 | # Cache directory, defaults to the operating systems temporary directory, if set to None
30 | cache_dir = None
31 |
--------------------------------------------------------------------------------
/devtools/ci/appveyor/runTestsuite.ps1:
--------------------------------------------------------------------------------
1 | function xslt_transform($xml, $xsl, $output)
2 | {
3 | trap [Exception]
4 | {
5 | Write-Host $_.Exception
6 | }
7 |
8 | $xslt = New-Object System.Xml.Xsl.XslCompiledTransform
9 | $xslt.Load($xsl)
10 | $xslt.Transform($xml, $output)
11 | }
12 |
13 | function upload($file) {
14 | trap [Exception]
15 | {
16 | Write-Host $_.Exception
17 | }
18 |
19 | $wc = New-Object 'System.Net.WebClient'
20 | $wc.UploadFile("https://ci.appveyor.com/api/testresults/xunit/$($env:APPVEYOR_JOB_ID)", $file)
21 | }
22 |
23 | function run {
24 | cd $env:APPVEYOR_BUILD_FOLDER
25 | $stylesheet = "devtools/ci/appveyor/transform_xunit_to_appveyor.xsl"
26 | $input = "nosetests.xml"
27 | $output = "transformed.xml"
28 |
29 | nosetests chainsaw --all-modules --with-xunit -a '!slow'
30 | $success = $?
31 | Write-Host "result code of nosetests:" $success
32 |
33 | xslt_transform $input $stylesheet $output
34 |
35 | upload $output
36 | Push-AppveyorArtifact $input
37 | Push-AppveyorArtifact $output
38 |
39 | # return exit code of testsuite
40 | if ( -not $success) {
41 | throw "testsuite not successful"
42 | }
43 | }
44 |
45 | run
46 |
--------------------------------------------------------------------------------
/devtools/conda-recipe/meta.yaml:
--------------------------------------------------------------------------------
1 | package:
2 | name: chainsaw
3 | # version number: [base tag]+[commits-upstream]_[git_hash]
4 | # eg. v2.0+0_g8824162
5 | version: {{ GIT_DESCRIBE_TAG[1:] + '+' +GIT_BUILD_STR}}
6 | source:
7 | path: ../..
8 |
9 | build:
10 | preserve_egg_dir: True
11 |
12 | requirements:
13 | build:
14 | - python
15 | - setuptools
16 | - cython >=0.20
17 | - mock
18 | - mdtraj # TODO: make it optional?
19 | - funcsigs
20 | - numpy x.x
21 | - h5py
22 | - six
23 | - psutil >=3.1.1
24 | - decorator >=4.0.0
25 | - progress_reporter
26 | - pyyaml
27 | - nomkl
28 | - msmtools # TODO: remove
29 |
30 | run:
31 | - python
32 | - setuptools
33 | - mock
34 | - mdtraj # TODO: make it optional?
35 | - funcsigs
36 | - numpy x.x
37 | - h5py
38 | - six
39 | - psutil >=3.1.1
40 | - decorator >=4.0.0
41 | - progress_reporter
42 | - pyyaml
43 | - nomkl
44 | - msmtools # TODO: remove
45 |
46 | test:
47 | requires:
48 | - nose
49 | - coverage ==4
50 | imports:
51 | - chainsaw
52 |
53 | about:
54 | home: http://emma-project.org
55 | license: GNU Lesser General Public License v3+
56 | summary: "data pipe"
57 |
58 |
59 |
--------------------------------------------------------------------------------
/devtools/ci/travis/after_success.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # The pull request number if the current job is a pull request, “false” if it’s not a pull request.
4 | if [[ ! "$TRAVIS_PULL_REQUEST" == "false" ]]; then
5 | echo "This is a pull request. No deployment will be done."; exit 0
6 | fi
7 |
8 | # For builds not triggered by a pull request this is the name of the branch currently being built;
9 | # whereas for builds triggered by a pull request this is the name of the branch targeted by the pull request (in many cases this will be master).
10 | if [ "$TRAVIS_BRANCH" != "devel" ]; then
11 | echo "No deployment on BRANCH='$TRAVIS_BRANCH'"; exit 0
12 | fi
13 |
14 |
15 | # Deploy to binstar
16 | conda install --yes anaconda-client jinja2
17 | pushd .
18 | cd $HOME/miniconda/conda-bld
19 | FILES=*/${PACKAGENAME}-dev-*.tar.bz2
20 | for filename in $FILES; do
21 | anaconda -t $BINSTAR_TOKEN upload --force -u ${ORGNAME} -p ${PACKAGENAME}-dev $filename
22 | done
23 | popd
24 |
25 | # call cleanup only for py35, numpy111
26 | if [[ "$CONDA_PY" == "3.5" && "$CONDA_NPY" == "111" && "$TRAVIS_OS_NAME" == "linux" ]]; then
27 | python devtools/ci/travis/dev_pkgs_del_old.py
28 | else
29 | echo "only executing cleanup script for py35 && npy111 && linux"
30 | fi
31 |
32 |
--------------------------------------------------------------------------------
/chainsaw/tests/test_format_registry.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | from chainsaw.data.util.fileformat_registry import FileFormatRegistry, FileFormatUnsupported
4 | from chainsaw.data._base.datasource import DataSource
5 |
6 |
7 | class TestFormatRegistry(unittest.TestCase):
8 | def test_multiple_calls_register(self):
9 | with self.assertRaises(RuntimeError) as exc:
10 | @FileFormatRegistry.register(".foo")
11 | @FileFormatRegistry.register(".bar")
12 | class test_src():
13 | pass
14 | self.assertIn("only once", exc.exception.args[0])
15 |
16 | def test_correct_reader_by_ext(self):
17 |
18 | @FileFormatRegistry.register(".foo")
19 | class test_src_foo(DataSource):
20 | pass
21 |
22 | @FileFormatRegistry.register(".bar")
23 | class test_src_bar(DataSource):
24 | pass
25 |
26 | self.assertEqual(FileFormatRegistry[".foo"], test_src_foo)
27 | self.assertEqual(FileFormatRegistry[".bar"], test_src_bar)
28 |
29 | def test_not_supported(self):
30 | ext = '.imsurethisisnotvalid'
31 | with self.assertRaises(FileFormatUnsupported) as exc:
32 | FileFormatRegistry[ext]
33 | self.assertIn(ext, exc.exception.args[0])
34 | self.assertIn('not supported', exc.exception.args[0])
35 |
--------------------------------------------------------------------------------
/chainsaw/transform/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
18 |
19 |
20 | r"""
21 | ===============================================================================
22 | transform - Transformation Utilities (:mod:`chainsaw.transform`)
23 | ===============================================================================
24 | .. currentmodule:: chainsaw.transform
25 |
26 | .. autosummary::
27 | :toctree: generated/
28 |
29 | PCA - principal components
30 | TICA - time independent components
31 | """
32 |
33 | from .pca import *
34 | from .tica import *
35 |
--------------------------------------------------------------------------------
/chainsaw/data/md/featurization/_base.py:
--------------------------------------------------------------------------------
1 | # This file is part of PyEMMA.
2 | #
3 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
4 | #
5 | # PyEMMA is free software: you can redistribute it and/or modify
6 | # it under the terms of the GNU Lesser General Public License as published by
7 | # the Free Software Foundation, either version 3 of the License, or
8 | # (at your option) any later version.
9 | #
10 | # This program is distributed in the hope that it will be useful,
11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | # GNU General Public License for more details.
14 | #
15 | # You should have received a copy of the GNU Lesser General Public License
16 | # along with this program. If not, see <http://www.gnu.org/licenses/>.
17 | '''
18 | Created on 15.02.2016
19 |
20 | @author: marscher
21 | '''
22 | from chainsaw.util.annotators import deprecated
23 |
24 |
25 | class Feature(object):
26 |
27 | @property
28 | def dimension(self):
29 | return self._dim
30 |
31 | @dimension.setter
32 | def dimension(self, val):
33 | assert isinstance(val, int)
34 | self._dim = val
35 |
36 | @deprecated('use transform(traj)')
37 | def map(self, traj):
38 | return self.transform(traj)
39 |
40 | def __eq__(self, other):
41 | return self.__hash__() == other.__hash__()
42 |
--------------------------------------------------------------------------------
/devtools/conda-recipe/run_test.py:
--------------------------------------------------------------------------------
1 |
2 | import subprocess
3 | import os
4 | import sys
5 | import shutil
6 | import re
7 |
# conda-build checkout location (NOTE(review): not referenced below -- kept for reference)
src_dir = os.getenv('SRC_DIR')

# package whose tests run, and the package coverage is measured for
test_pkg = 'chainsaw'
cover_pkg = test_pkg

# matplotlib headless backend
with open('matplotlibrc', 'w') as fh:
    fh.write('backend: Agg')
16 |
17 |
def coverage_report():
    """Copy the .coverage file into the Travis build dir and rewrite the
    recorded source paths from the conda test env back to the git clone."""
    fn = '.coverage'
    assert os.path.exists(fn)
    build_dir = os.getenv('TRAVIS_BUILD_DIR')
    dest = os.path.join(build_dir, fn)
    print( "copying coverage report to", dest)
    shutil.copy(fn, dest)
    assert os.path.exists(dest)

    # fix paths in .coverage file
    with open(dest, 'r') as fh:
        data = fh.read()
    # map site-packages paths inside the '_test' conda env onto the repo layout
    match= '"/.+?/miniconda/envs/_test/lib/python.+?/site-packages/.+?/({test_pkg}/.+?)"'.format(test_pkg=test_pkg)
    repl = '"%s/\\1"' % build_dir
    data = re.sub(match, repl, data)
    os.unlink(dest)
    with open(dest, 'w+') as fh:
        fh.write(data)
36 |
# run the suite under nose with coverage and doctest collection enabled
nose_run = "nosetests {test_pkg} -vv" \
           " --with-coverage --cover-inclusive --cover-package={cover_pkg}" \
           " --with-doctest --doctest-options=+NORMALIZE_WHITESPACE,+ELLIPSIS" \
           .format(test_pkg=test_pkg, cover_pkg=cover_pkg).split(' ')

res = subprocess.call(nose_run)


# move .coverage file to git clone on Travis CI
if os.getenv('TRAVIS', False):
    coverage_report()

# propagate the nose exit status to the caller
sys.exit(res)
50 |
51 |
--------------------------------------------------------------------------------
/chainsaw/data/util/fileformat_registry.py:
--------------------------------------------------------------------------------
1 |
class FileFormatUnsupported(Exception):
    """ No available reader is able to handle the given extension.

    Raised by ``FileFormatRegistry.__getitem__`` on unknown extensions.
    """
4 |
5 |
class FileFormatRegistry(object):
    """ Registry for trajectory file objects.

    Maps file extensions to reader classes; reader classes announce their
    extensions via the :meth:`register` class decorator.
    """
    # extension (str) -> reader class
    _readers = {}

    @classmethod
    def register(cls, *args):
        """ register the given class as Reader class for given extension(s). """
        def decorator(f):
            # use equality, not 'is not ()': identity comparison against a
            # tuple literal only works by CPython interning accident and is a
            # SyntaxWarning on modern interpreters.
            if getattr(f, 'SUPPORTED_EXTENSIONS', ()) != ():
                raise RuntimeError("please call register() only once per class!")
            extensions = tuple(args)
            cls._readers.update({e: f for e in extensions})
            f.SUPPORTED_EXTENSIONS = extensions
            return f
        return decorator

    @staticmethod
    def is_md_format(extension):
        """ whether the extension belongs to the (optional) mdtraj-based FeatureReader. """
        import chainsaw.data.md as md
        if hasattr(md, 'FeatureReader'):
            from chainsaw.data.md import FeatureReader
            return extension in FeatureReader.SUPPORTED_EXTENSIONS

        return False

    def supported_extensions(self):
        """ all extensions currently handled by some registered reader. """
        return self._readers.keys()

    def __getitem__(self, item):
        try:
            return self._readers[item]
        except KeyError:
            raise FileFormatUnsupported("Extension {ext} is not supported by any reader.".format(ext=item))

# singleton pattern
FileFormatRegistry = FileFormatRegistry()
42 |
--------------------------------------------------------------------------------
/chainsaw/clustering/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see .
18 |
19 |
20 | r"""
21 | ===============================================================================
22 | clustering - Algorithms (:mod:`chainsaw.clustering`)
23 | ===============================================================================
24 |
.. currentmodule:: chainsaw.clustering
26 |
27 | .. autosummary::
28 | :toctree: generated/
29 |
30 | AssignCenters
31 | KmeansClustering
32 | RegularSpaceClustering
33 | UniformTimeClustering
34 | """
35 |
36 | from .assign import AssignCenters
37 | from .kmeans import KmeansClustering
38 | from .kmeans import MiniBatchKmeansClustering
39 | from .regspace import RegularSpaceClustering
40 | from .uniform_time import UniformTimeClustering
--------------------------------------------------------------------------------
/devtools/ci/travis/dev_pkgs_del_old.py:
--------------------------------------------------------------------------------
1 | """
2 | Cleanup old development builds on Anaconda.org
3 |
4 | Assumes one has set 4 environment variables:
5 |
6 | 1. BINSTAR_TOKEN: token to authenticate with anaconda.org
7 | 2. DEV_BUILD_N_KEEP: int, how many builds to keep, delete oldest first.
8 | 3. ORGNAME: str, anaconda.org organisation/user
4. PACKAGENAME: str, name of package to clean up
10 |
11 | author: Martin K. Scherer
date: 20.4.16
13 | """
14 | from __future__ import print_function, absolute_import
15 |
16 | import os
17 |
18 | from binstar_client.utils import get_server_api
19 | from pkg_resources import parse_version
20 |
# credentials and target package, all taken from the environment (see module docstring)
token = os.environ['BINSTAR_TOKEN']
org = os.environ['ORGNAME']
pkg = os.environ['PACKAGENAME']
# keep this many of the newest versions online
n_keep = int(os.getenv('DEV_BUILD_N_KEEP', 10))

b = get_server_api(token=token)
package = b.package(org, pkg)

# sort releases newest first, so .pop() below yields the oldest remaining one
sorted_by_version = sorted(package['releases'],
                           key=lambda rel: parse_version(rel['version']),
                           reverse=True
                           )
to_delete = []
# clamp at zero: with fewer than n_keep versions online the old code printed
# a negative removal count.
n_remove = max(0, len(sorted_by_version) - n_keep)
print("Currently have {n} versions online. Going to remove {x}.".
      format(n=len(sorted_by_version), x=n_remove))

while len(sorted_by_version) > n_keep:
    to_delete.append(sorted_by_version.pop())


# remove old releases from anaconda.org
for rel in to_delete:
    version = rel['version']
    print("removing {version}".format(version=version))
    b.remove_release(org, pkg, version)
47 |
--------------------------------------------------------------------------------
/chainsaw/tests/test_acf.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see .
18 |
19 |
20 |
21 | from __future__ import absolute_import
22 | import unittest
23 | import numpy as np
24 |
25 | from chainsaw.acf import acf
26 | from six.moves import range
27 |
28 |
class TestACF(unittest.TestCase):
    def test(self):
        # random input: one trajectory with three observables
        data = np.random.rand(100, 3)

        computed = acf(data)

        # reference: direct (zero-padded) evaluation of the autocorrelation
        n_frames = data.shape[0]
        centered = data - data.mean(axis=0)
        padded = np.vstack((centered, np.zeros_like(data)))
        reference = np.array([
            (padded[0:n_frames, :] * padded[lag:n_frames + lag, :]).sum(axis=0) / (n_frames - lag)
            for lag in range(n_frames)])
        reference /= reference[0]  # normalize

        np.testing.assert_allclose(reference, computed)

if __name__ == "__main__":
    unittest.main()
--------------------------------------------------------------------------------
/devtools/ci/appveyor/run_with_env.cmd:
--------------------------------------------------------------------------------
1 | :: To build extensions for 64 bit Python 3, we need to configure environment
2 | :: variables to use the MSVC 2010 C++ compilers from GRMSDKX_EN_DVD.iso of:
3 | :: MS Windows SDK for Windows 7 and .NET Framework 4 (SDK v7.1)
4 | ::
5 | :: To build extensions for 64 bit Python 2, we need to configure environment
6 | :: variables to use the MSVC 2008 C++ compilers from GRMSDKX_EN_DVD.iso of:
7 | :: MS Windows SDK for Windows 7 and .NET Framework 3.5 (SDK v7.0)
8 | ::
9 | :: 32 bit builds do not require specific environment configurations.
10 | ::
11 | :: Note: this script needs to be run with the /E:ON and /V:ON flags for the
12 | :: cmd interpreter, at least for (SDK v7.0)
13 | ::
14 | :: More details at:
15 | :: https://github.com/cython/cython/wiki/64BitCythonExtensionsOnWindows
16 | :: http://stackoverflow.com/a/13751649/163740
17 | ::
18 | :: Author: Olivier Grisel
19 | :: License: CC0 1.0 Universal: http://creativecommons.org/publicdomain/zero/1.0/
@ECHO OFF

:: remember the full command line that should run under the configured SDK env
SET COMMAND_TO_RUN=%*
SET WIN_SDK_ROOT=C:\Program Files\Microsoft SDKs\Windows

:: first digit of CONDA_PY picks the matching SDK (v7.0 for py2, v7.1 for py3)
SET MAJOR_PYTHON_VERSION="%CONDA_PY:~0,1%"
IF %MAJOR_PYTHON_VERSION% == "2" (
    SET WINDOWS_SDK_VERSION="v7.0"
) ELSE IF %MAJOR_PYTHON_VERSION% == "3" (
    SET WINDOWS_SDK_VERSION="v7.1"
) ELSE (
    ECHO Unsupported Python version: "%MAJOR_PYTHON_VERSION%"
    EXIT 1
)

:: 64 bit builds need the SDK environment; 32 bit works with the defaults
IF "%ARCH%"=="64" (
    ECHO Configuring Windows SDK %WINDOWS_SDK_VERSION% for Python %MAJOR_PYTHON_VERSION% on a 64 bit architecture
    SET DISTUTILS_USE_SDK=1
    SET MSSdk=1
    "%WIN_SDK_ROOT%\%WINDOWS_SDK_VERSION%\Setup\WindowsSdkVer.exe" -q -version:%WINDOWS_SDK_VERSION%
    "%WIN_SDK_ROOT%\%WINDOWS_SDK_VERSION%\Bin\SetEnv.cmd" /x64 /release
    ECHO Executing: %COMMAND_TO_RUN%
    call %COMMAND_TO_RUN% || EXIT 1
) ELSE (
    ECHO Using default MSVC build environment for 32 bit architecture
    ECHO Executing: %COMMAND_TO_RUN%
    call %COMMAND_TO_RUN% || EXIT 1
)
48 |
--------------------------------------------------------------------------------
/chainsaw/data/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see .
18 |
19 |
20 | r"""
21 | ===============================================================================
22 | data - Data and input/output utilities (:mod:`chainsaw.data`)
23 | ===============================================================================
24 |
.. currentmodule:: chainsaw.data
26 |
27 | Order parameters
28 | ================
29 |
30 | .. autosummary::
31 | :toctree: generated/
32 |
33 | MDFeaturizer - selects and computes features from MD trajectories
34 | CustomFeature - define arbitrary function to extract features
35 |
36 | Reader
37 | ======
38 |
39 | .. autosummary::
40 | :toctree: generated/
41 |
42 | FeatureReader - reads features via featurizer
43 | NumPyFileReader - reads numpy files
44 | PyCSVReader - reads tabulated ascii files
45 | DataInMemory - used if data is already available in mem
46 |
47 | """
48 | from chainsaw.data.md.feature_reader import FeatureReader
49 |
50 | from .data_in_memory import DataInMemory
51 | from .numpy_filereader import NumPyFileReader
52 | from .py_csv_reader import PyCSVReader
53 | from .util.reader_utils import create_file_reader
54 |
55 |
56 |
# TODO: if mdtraj avail, import FeatureReader to register md extensions.
58 |
--------------------------------------------------------------------------------
/chainsaw/tests/test_uniform_time.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see .
18 |
19 |
20 | '''
21 | Created on 09.04.2015
22 |
23 | @author: marscher
24 | '''
25 |
26 | from __future__ import absolute_import
27 | import unittest
28 |
29 | import numpy as np
30 |
31 | from chainsaw import api
32 | from chainsaw.data.data_in_memory import DataInMemory
33 |
34 |
class TestUniformTimeClustering(unittest.TestCase):

    def _cluster(self, data, **kw):
        # helper: wrap the array into an in-memory reader and parametrize
        # a uniform-time clustering on it
        reader = DataInMemory(data)
        clustering = api.cluster_uniform_time(**kw)
        clustering.data_producer = reader
        clustering.parametrize()
        return clustering

    def test_1d(self):
        self._cluster(np.random.random(1000), k=2)

    def test_2d(self):
        self._cluster(np.random.random((300, 3)), k=2)

    def test_2d_skip(self):
        self._cluster(np.random.random((300, 3)), k=2, skip=100)

    def test_big_k(self):
        # more centers than half of the data length
        self._cluster(np.random.random((300, 3)), k=151)


if __name__ == "__main__":
    unittest.main()
--------------------------------------------------------------------------------
/chainsaw/util/contexts.py:
--------------------------------------------------------------------------------
1 | '''
2 | Created on 04.01.2016
3 |
4 | @author: marscher
5 | '''
6 | from contextlib import contextmanager
7 | import random
8 |
9 | import numpy as np
10 |
11 |
class conditional(object):
    """Conditionally delegate to a wrapped context manager.

    When *condition* is true, entering/exiting this object enters/exits the
    wrapped manager; otherwise both operations are no-ops (``__enter__``
    then yields ``None``).
    """

    def __init__(self, condition, contextmanager):
        self.condition = condition
        self.contextmanager = contextmanager

    def __enter__(self):
        if not self.condition:
            return None
        return self.contextmanager.__enter__()

    def __exit__(self, *exc_info):
        if not self.condition:
            return None
        return self.contextmanager.__exit__(*exc_info)
27 |
28 |
@contextmanager
def numpy_random_seed(seed=42):
    """ sets the random seed of numpy within the context.

    The previous global RNG state is restored when the context exits,
    even if the body raises.

    Example
    -------
    >>> import numpy as np
    >>> with numpy_random_seed(seed=0):
    ...    np.random.randint(1000)
    684
    """
    saved_state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(saved_state)
46 |
47 |
@contextmanager
def random_seed(seed=42):
    """ sets the random seed of Python within the context.

    The previous RNG state is restored on exit, even if the body raises.

    Example
    -------
    >>> import random
    >>> with random_seed(seed=0):
    ...    random.randint(0, 1000) # doctest: +SKIP
    864
    """
    saved_state = random.getstate()
    random.seed(seed)
    try:
        yield
    finally:
        random.setstate(saved_state)
65 |
66 |
@contextmanager
def settings(**kwargs):
    """ apply given PyEMMA config values temporarily within the given context.

    Parameters are validated against the known config keys; a ValueError is
    raised for unknown keys. The previous values are restored on exit --
    also when the body raises, matching the seed context managers above.
    """
    from chainsaw import config
    # validate:
    valid_keys = config.keys()
    for k in kwargs.keys():
        if k not in valid_keys:
            raise ValueError("not a valid settings: {key}".format(key=k))

    old_settings = {}
    for k, v in kwargs.items():
        old_settings[k] = getattr(config, k)
        setattr(config, k, v)

    try:
        yield
    finally:
        # restore old settings even on exceptions raised inside the context
        for k, v in old_settings.items():
            setattr(config, k, v)
87 |
--------------------------------------------------------------------------------
/chainsaw/tests/test_mini_batch_kmeans.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see .
18 |
19 | from __future__ import absolute_import
20 | import unittest
21 | from unittest import TestCase
22 | import numpy as np
23 | from chainsaw.api import cluster_mini_batch_kmeans
24 |
25 |
class TestMiniBatchKmeans(TestCase):

    def _assert_centers_spread(self, cc):
        # centers should cover all three gaussian modes
        assert (np.any(cc < 1.0))
        assert (np.any((cc > -1.0) * (cc < 1.0)))
        assert (np.any(cc > -1.0))

    def test_3gaussian_1d_singletraj(self):
        # 1D data drawn from three shifted gaussians, one concatenated traj
        chunks = [np.random.randn(200) - 2.0,
                  np.random.randn(300),
                  np.random.randn(400) + 2.0]
        data = np.hstack(chunks)
        kmeans = cluster_mini_batch_kmeans(data, batch_size=0.5, k=100, max_iter=10000)
        self._assert_centers_spread(kmeans.clustercenters)

    def test_3gaussian_2d_multitraj(self):
        # 2D data whose first column is drawn from three shifted gaussians,
        # split over three trajectories
        def _traj(n, shift):
            t = np.zeros((n, 2))
            t[:, 0] = np.random.randn(n) + shift
            return t
        data = [_traj(200, -2.0), _traj(300, 0.0), _traj(400, 2.0)]
        kmeans = cluster_mini_batch_kmeans(data, batch_size=0.5, k=100, max_iter=10000)
        self._assert_centers_spread(kmeans.clustercenters)

if __name__ == '__main__':
    unittest.main()
--------------------------------------------------------------------------------
/chainsaw/util/exceptions.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see .
18 |
19 | '''
20 | Created on May 26, 2014
21 |
22 | @author: marscher
23 | '''
24 |
25 |
class SpectralWarning(RuntimeWarning):
    """Base warning class for spectral issues (ImaginaryEigenValueWarning derives from it)."""
    pass
28 |
29 |
class ImaginaryEigenValueWarning(SpectralWarning):
    """Specialization of SpectralWarning for imaginary eigenvalues."""
    pass
32 |
33 |
class PrecisionWarning(RuntimeWarning):
    r"""
    This warning indicates that some operation in your code leads
    to a conversion of datatypes, which involves a loss/gain in
    precision (e.g. float64 down to float32, or vice versa).

    """
    pass
42 |
43 |
class NotConvergedWarning(RuntimeWarning):
    r"""
    This warning indicates that some iterative procedure has not
    converged, or reached the maximum number of iterations implemented
    as a safe guard to prevent arbitrarily many iterations in loops with
    a conditional termination criterion.

    """
    pass
53 |
54 |
class EfficiencyWarning(UserWarning):
    r"""Some operation or input data leads to a lack of efficiency (e.g. memory or runtime)."""
    pass
58 |
59 |
class ParserWarning(UserWarning):
    """ Some user-defined variable could not be parsed and is ignored/replaced by a default. """
    pass
63 |
64 |
class ConfigDirectoryException(Exception):
    """ Some operation with the configuration directory went wrong (e.g. creation failed). """
    pass
68 |
69 |
# NOTE: name violates PascalCase convention but is kept for backward compatibility.
class Chainsaw_DeprecationWarning(UserWarning):
    """You are using a feature, which will be removed in a future release. You have been warned!"""
    pass
--------------------------------------------------------------------------------
/chainsaw/util/change_notification.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see .
18 |
19 | '''
20 | Created on 14.01.2016
21 |
22 | @author: marscher
23 | '''
24 |
25 |
def inform_children_upon_change(f):
    """ decorator to call interface method '_stream_on_change' of NotifyOnChangesMixIn
    instances after the decorated method has run. """
    from functools import wraps

    # preserve the wrapped method's name/docstring for introspection tools
    @wraps(f)
    def _notify(self, *args, **kw):
        # first call the decorated function, then inform about the change.
        res = f(self, *args, **kw)
        self._stream_on_change()
        return res

    return _notify
37 |
38 |
class NotifyOnChangesMixIn(object):
    """Mixin keeping a list of child streams that get notified about changes."""
    #### interface to handle events

    @property
    def _stream_children(self):
        # created lazily so subclasses need not cooperate in __init__
        try:
            return self._stream_children_list
        except AttributeError:
            self._stream_children_list = []
            return self._stream_children_list

    def _stream_register_child(self, data_producer):
        """ should be called upon setting of data_producer """
        self._stream_children.append(data_producer)

    def _stream_unregister_child(self, child):
        children = self._stream_children
        if child in children:
            children.remove(child)
        else:
            print("should not happen")

    def _stream_on_change(self):
        # default: no-op; subclasses override to react to changes
        pass

    def _stream_inform_children_param_change(self):
        """ will inform all children about a parameter change in general """
        for child in self._stream_children:
            child._stream_on_change()
65 |
--------------------------------------------------------------------------------
/chainsaw/util/files.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see .
18 |
19 |
20 | '''
21 | Created on 17.02.2014
22 |
23 | @author: marscher
24 | '''
25 | from __future__ import absolute_import, print_function
26 |
27 | import os
28 | import errno
29 | import tempfile
30 | import shutil
31 |
32 |
def mkdir_p(path):
    """Create *path* including all parents (like ``mkdir -p``).

    An already existing directory is not an error; any other OSError
    is re-raised.
    """
    try:
        os.makedirs(path)
    except OSError as exc:  # Python >2.5
        # tolerate a pre-existing directory, propagate everything else
        if not (exc.errno == errno.EEXIST and os.path.isdir(path)):
            raise
41 |
42 |
class TemporaryDirectory(object):
    """Context manager around :func:`tempfile.mkdtemp`.

    Creates a fresh temporary directory on entry and removes it (with all
    contents) on exit, errors during removal being ignored.

    Examples
    --------
    >>> import os
    >>> with TemporaryDirectory() as tmp:
    ...    path = os.path.join(tmp, "myfile.dat")
    ...    fh = open(path, 'w')
    ...    _ = fh.write('hello world')
    ...    fh.close()
    """

    def __init__(self, prefix='', suffix='', dir=None):
        self.prefix = prefix
        self.suffix = suffix
        self.dir = dir
        # set on __enter__, holds the created directory path
        self.tmpdir = None

    def __enter__(self):
        self.tmpdir = tempfile.mkdtemp(suffix=self.suffix,
                                       prefix=self.prefix,
                                       dir=self.dir)
        return self.tmpdir

    def __exit__(self, *exc_info):
        shutil.rmtree(self.tmpdir, ignore_errors=True)
--------------------------------------------------------------------------------
/devtools/ci/jenkins/update_versions_json.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # author: marscher
3 | # purpose: update version.json file on new software release.
4 | from __future__ import print_function
5 | import json
6 | import sys
7 |
8 | from argparse import ArgumentParser
9 |
10 | from six.moves.urllib.request import urlopen
11 | from distutils.version import LooseVersion as parse
12 |
13 |
def hash_(self):
    # delegate hashing to the raw version string so LooseVersion objects can
    # be put into a set() in main() below.
    # NOTE(review): presumably LooseVersion lacks a usable __hash__ in the
    # targeted Python version -- confirm against the distutils in use.
    return hash(self.vstring)
# monkey-patch LooseVersion to be hashable
parse.__hash__ = hash_
17 |
18 |
def make_version_dict(URL, version, url_prefix='v', latest=False):
    """Build one entry of the versions.json list.

    The url joins base URL and the tag name; git tags look like ``vx.y.z``,
    hence the default prefix.
    """
    entry = {'version': version,
             'url': '{base}/{prefix}{ver}'.format(base=URL, prefix=url_prefix, ver=version),
             'latest': latest}
    return entry
24 |
25 |
def find_latest(versions):
    """Return the first entry whose 'latest' flag equals True (None if absent)."""
    return next((v for v in versions if v['latest'] == True), None)
30 |
31 |
def main(argv=None):
    '''Update the versions.json hosted at --url.

    Fetches the current list, optionally appends a new version entry,
    deduplicates and sorts, re-flags the highest release as latest and
    appends a trailing 'devel' entry.
    '''

    if argv is None:
        argv = sys.argv
    else:
        # NOTE(review): extends (rather than replaces) sys.argv for argparse;
        # verify this is intended when calling main() with an explicit argv.
        sys.argv.extend(argv)

    parser = ArgumentParser()
    parser.add_argument('-u', '--url', dest='url', required=True, help="base url (has to contain versions json)")
    parser.add_argument('-o', '--output', dest='output')
    parser.add_argument('-a', '--add_version', dest='version')
    parser.add_argument('-v', '--verbose', dest='verbose', action='store_true')
    parser.add_argument('-l', '--latest-version', dest='latest', action='store_true')

    args = parser.parse_args()

    URL = args.url
    # get dict
    versions = json.load(urlopen(URL + '/versions.json'))
    # add new version
    if args.version:
        versions.append(make_version_dict(URL, args.version))

    # create Version objects to compare them
    version_objs = [parse(s['version']) for s in versions]

    # unify and sort
    version_objs = set(version_objs)
    version_objs = sorted(list(version_objs))

    # rebuild the entry list, skipping 'devel' (re-added explicitly below)
    versions = [make_version_dict(URL, str(v)) for v in version_objs if v != 'devel']

    # last element should be the highest version
    versions[-1]['latest'] = True
    # 'devel' always goes last, never flagged as latest, no 'v' tag prefix
    versions.append(make_version_dict(URL, 'devel', '', False))

    if args.verbose:
        print ("new versions json:", versions)

    if args.latest:
        print(find_latest(versions)['version'])
        return 0

    if args.output:
        with open(args.output, 'w') as v:
            json.dump(versions, v)
            v.flush()

if __name__ == '__main__':
    sys.exit(main())
83 |
--------------------------------------------------------------------------------
/chainsaw/util/indices.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see .
18 |
19 | from __future__ import absolute_import
20 | import numpy as np
21 |
22 |
def combinations(seq, k):
    """ Return k length subsequences of elements from the input iterable.

    This version uses Numpy and should be preferred over itertools. It avoids
    the creation of all intermediate Python objects.

    Examples
    --------

    >>> import numpy as np
    >>> from itertools import combinations as iter_comb
    >>> x = np.arange(3)
    >>> c1 = combinations(x, 2)
    >>> print(c1) # doctest: +NORMALIZE_WHITESPACE
    [[0 1]
     [0 2]
     [1 2]]
    >>> c2 = np.array(tuple(iter_comb(x, 2)))
    >>> print(c2) # doctest: +NORMALIZE_WHITESPACE
    [[0 1]
     [0 2]
     [1 2]]
    """
    from itertools import combinations as _combinations, chain
    from math import factorial

    n = len(seq)
    # exact binomial coefficient C(n, k); scipy.misc.comb has been removed
    # from modern SciPy, so compute it with the standard library instead.
    count = 0 if k > n else factorial(n) // (factorial(k) * factorial(n - k))
    res = np.fromiter(chain.from_iterable(_combinations(seq, k)),
                      int, count=count*k)
    return res.reshape(-1, k)
53 |
54 |
def product(*arrays):
    """ Generate a cartesian product of input arrays.

    Parameters
    ----------
    arrays : list of array-like
        1-D arrays to form the cartesian product of.

    Returns
    -------
    out : ndarray
        2-D array of shape (M, len(arrays)) containing cartesian products
        formed of input arrays.

    """
    arrays = [np.asarray(a) for a in arrays]
    dtype = arrays[0].dtype

    # index grid: one row per product element, one column per input array
    grid = np.indices([len(a) for a in arrays])
    grid = grid.reshape(len(arrays), -1).T

    out = np.empty_like(grid, dtype=dtype)
    for col, arr in enumerate(arrays):
        out[:, col] = arr[grid[:, col]]

    return out
83 |
--------------------------------------------------------------------------------
/chainsaw/tests/util.py:
--------------------------------------------------------------------------------
1 | import tempfile
2 | import numpy as np
3 | import mdtraj
4 | import pkg_resources
5 | import os
6 | import logging
7 |
8 |
class MockLoggingHandler(logging.Handler):
    """Mock logging handler to check for expected logs.

    Every emitted record's formatted message is collected in
    ``self.messages``, keyed by the lower-cased level name.
    """

    _LEVELS = ('debug', 'info', 'warning', 'error', 'critical')

    def __init__(self, *args, **kwargs):
        self.reset()
        logging.Handler.__init__(self, *args, **kwargs)

    def emit(self, record):
        # file the formatted message under its level name
        level = record.levelname.lower()
        self.messages[level].append(record.getMessage())

    def reset(self):
        # fresh, empty message list for each level
        self.messages = {level: [] for level in self._LEVELS}
27 |
def _setup_testing():
    """Apply global chainsaw settings suitable for running the test suite."""
    from chainsaw import config
    # temporary trajectory files must not pollute the user's cache directory
    config.use_trajectory_lengths_cache = False
    # progress bars only clutter test output
    config.show_progress_bars = False
34 |
def _monkey_patch_testing_apply_setting():
    """Monkey patch the __init__ methods of unittest.TestCase and
    doctest.DocTestFinder so that the internal testing settings
    (see _setup_testing) are applied automatically."""
    import doctest
    import unittest

    original_testcase_init = unittest.TestCase.__init__

    def _patched_testcase_init(self, *args, **kwargs):
        # run the real constructor first, then apply the test settings
        original_testcase_init(self, *args, **kwargs)
        _setup_testing()

    unittest.TestCase.__init__ = _patched_testcase_init

    original_finder_init = doctest.DocTestFinder.__init__

    def _patched_finder_init(self, *args, **kw):
        # apply the test settings before the finder is constructed
        _setup_testing()
        original_finder_init(self, *args, **kw)

    doctest.DocTestFinder.__init__ = _patched_finder_init
54 |
55 |
def get_bpti_test_data():
    """Locate the BPTI test trajectories and topology bundled with the tests.

    Returns
    -------
    xtcfiles : list of str
        sorted list of the bpti_0*.xtc trajectory chunk files.
    pdbfile : str
        path to the C-alpha topology file (bpti_ca.pdb).
    """
    import pkg_resources
    from glob import glob
    path = pkg_resources.resource_filename(__name__, 'data')
    # sort for a deterministic trajectory order; glob's order is
    # filesystem-dependent. os.path.join avoids the doubled separator of
    # the previous string concatenation.
    xtcfiles = sorted(glob(os.path.join(path, 'bpti_0*.xtc')))
    pdbfile = os.path.join(path, 'bpti_ca.pdb')
    assert xtcfiles, 'no bpti_0*.xtc files found in %s' % path
    # the old check "assert pdbfile" was always true (non-empty string);
    # check that the file actually exists instead.
    assert os.path.exists(pdbfile), pdbfile

    return xtcfiles, pdbfile
66 |
67 |
def get_top():
    """Return the path of the packaged topology file chainsaw/tests/data/test.pdb."""
    return pkg_resources.resource_filename(__name__, 'data/test.pdb')
70 |
71 |
def create_traj(top=None, format='.xtc', dir=None, length=1000, start=0):
    """Write a temporary trajectory file with deterministic coordinates.

    Parameters
    ----------
    top : str, optional
        topology file to load; defaults to the bundled test.pdb (get_top()).
    format : str, default='.xtc'
        file suffix; determines the format mdtraj writes.
    dir : str, optional
        directory in which the temporary file is created.
    length : int, default=1000
        number of frames to generate.
    start : int, default=0
        offset of the generated coordinate values, so that consecutive
        trajectories can be made contiguous.

    Returns
    -------
    trajfile : str
        path of the written trajectory file (caller is responsible for removal).
    xyz : ndarray, shape (length, 3, 3)
        the coordinates that were written.
    length : int
        the number of frames written (same as the input argument).
    """
    # NamedTemporaryFile(delete=False) reserves the name on creation and thus
    # avoids the race condition of the deprecated tempfile.mktemp().
    with tempfile.NamedTemporaryFile(suffix=format, dir=dir, delete=False) as fh:
        trajfile = fh.name
    xyz = np.arange(start * 3 * 3, (start + length) * 3 * 3)
    xyz = xyz.reshape((-1, 3, 3))
    if top is None:
        top = get_top()

    t = mdtraj.load(top)
    t.xyz = xyz
    t.unitcell_vectors = np.array(length * [[0, 0, 1], [0, 1, 0], [1, 0, 0]]).reshape(length, 3, 3)
    t.time = np.arange(length)
    t.save(trajfile)

    return trajfile, xyz, length
86 |
--------------------------------------------------------------------------------
/chainsaw/tests/test_cluster_samples.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see .
18 |
19 |
20 | """
21 | Test the save_trajs function of the coordinates API by comparing
22 | the direct, sequential retrieval of frames via mdtraj.load_frame() vs
23 | the retrival via save_trajs
24 | @author: gph82, clonker
25 | """
26 |
27 | from __future__ import absolute_import
28 |
29 | import unittest
30 |
31 | import numpy as np
32 | import chainsaw as coor
33 |
34 |
class TestClusterSamples(unittest.TestCase):

    def setUp(self):
        # two identical passes over three trajectories of three frames each,
        # so each of the nine distinct values appears exactly twice
        self.input_trajs = [[0, 1, 2],
                            [3, 4, 5],
                            [6, 7, 8],
                            [0, 1, 2],
                            [3, 4, 5],
                            [6, 7, 8]]
        self.cluster_obj = coor.cluster_regspace(data=self.input_trajs, dmin=.5)

    def test_index_states(self):
        # Test that the catalogue is being set up properly: every cluster is
        # visited once in trajectory i and once in trajectory i + 3, at the
        # same within-trajectory position.
        expected = [[[0, 0], [3, 0]],  # appearances of the 1st cluster
                    [[0, 1], [3, 1]],  # appearances of the 2nd cluster
                    [[0, 2], [3, 2]],  # appearances of the 3rd cluster
                    [[1, 0], [4, 0]],  # .....
                    [[1, 1], [4, 1]],
                    [[1, 2], [4, 2]],
                    [[2, 0], [5, 0]],
                    [[2, 1], [5, 1]],
                    [[2, 2], [5, 2]],
                    ]

        for cluster_index in np.arange(self.cluster_obj.n_clusters):
            assert np.allclose(self.cluster_obj.index_clusters[cluster_index],
                               expected[cluster_index])

    def test_sample_indexes_by_state(self):
        n_clusters = self.cluster_obj.n_clusters
        samples = self.cluster_obj.sample_indexes_by_cluster(np.arange(n_clusters), 10)

        # For each sample, check that you're actually retrieving the i-th center
        for state, state_samples in enumerate(samples):
            drawn = [self.cluster_obj.dtrajs[traj][frame] for traj, frame in state_samples]
            assert np.in1d(drawn, state).all()
70 |
71 |
72 | if __name__ == "__main__":
73 | unittest.main()
--------------------------------------------------------------------------------
/chainsaw/_ext/variational_estimators/lagged_correlation.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 | __author__ = 'noe'
4 |
5 | import numpy as np
6 |
class LaggedCorrelation(object):
    """Running estimator of the instantaneous (C0) and time-lagged (Ctau)
    correlation matrices over a set of basis-function trajectories."""

    def __init__(self, output_dimension, tau=1):
        """ Computes correlation matrices C0 and Ctau from a bunch of trajectories

        Parameters
        ----------
        output_dimension: int
            Number of basis functions.
        tau: int
            Lag time

        """
        self.tau = tau
        self.output_dimension = output_dimension
        # Initialize the two correlation matrices:
        self.Ct = np.zeros((self.output_dimension, self.output_dimension))
        self.C0 = np.zeros((self.output_dimension, self.output_dimension))
        # Create counter for the frames used for C0, Ct:
        self.nC0 = 0
        self.nCt = 0

    def add(self, X):
        """ Adds trajectory to the running estimate for computing mu, C0 and Ct:

        Parameters
        ----------
        X: ndarray (T,N)
            basis function trajectory of T time steps for N basis functions.

        Raises
        ------
        ValueError
            if X has the wrong number of basis functions (columns), or if it
            is too short (T <= tau) to contain any time-lagged pair.
        """
        # Raise ValueError (instead of a bare Exception) so callers can catch
        # the specific input problem; ValueError is a subclass of Exception,
        # so existing "except Exception" handlers keep working.
        if not (X.shape[1] == self.output_dimension):
            raise ValueError("Number of basis functions is incorrect. "
                             "Got %d, expected %d." % (X.shape[1],
                                                       self.output_dimension))
        if X.shape[0] <= self.tau:
            raise ValueError("Number of time steps is too small.")

        # Get the time-lagged data:
        Y1 = X[self.tau:, :]
        # Remove the last tau frames from X:
        Y2 = X[:-self.tau, :]
        # Number of lagged pairs in this trajectory (as float, so all the
        # normalizations below are carried out in floating point):
        TX = 1.0 * Y1.shape[0]
        # Update time-lagged correlation matrix:
        self.Ct += np.dot(Y1.T, Y2)
        self.nCt += TX
        # Update the instantaneous correlation matrix; both the lagged and
        # the leading part of the trajectory contribute:
        self.C0 += np.dot(Y1.T, Y1) + np.dot(Y2.T, Y2)
        self.nC0 += 2 * TX

    def GetC0(self):
        """ Returns the current estimate of C0:

        Returns
        -------
        C0: ndarray (N,N)
            time instantaneous correlation matrix of N basis function.

        """
        # symmetrize and normalize by (number of frames - 1)
        return 0.5 * (self.C0 + self.C0.T) / (self.nC0 - 1)

    def GetCt(self):
        """ Returns the current estimate of Ctau

        Returns
        -------
        Ct: ndarray (N,N)
            time lagged correlation matrix of N basis function.

        """
        # symmetrize and normalize by (number of lagged pairs - 1)
        return 0.5 * (self.Ct + self.Ct.T) / (self.nCt - 1)
81 |
--------------------------------------------------------------------------------
/chainsaw/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see .
18 |
19 |
20 | r"""
21 | .. currentmodule:: chainsaw
22 |
23 | User API
24 | ========
25 |
26 | **Trajectory input/output and featurization**
27 |
28 | .. autosummary::
29 | :toctree: generated/
30 |
31 | featurizer
32 | load
33 | source
34 | pipeline
35 | discretizer
36 | save_traj
37 | save_trajs
38 |
39 | **Coordinate and feature transformations**
40 |
41 | .. autosummary::
42 | :toctree: generated/
43 |
44 | pca
45 | tica
46 |
47 | **Clustering Algorithms**
48 |
49 | .. autosummary::
50 | :toctree: generated/
51 |
52 | cluster_kmeans
53 | cluster_mini_batch_kmeans
54 | cluster_regspace
55 | cluster_uniform_time
56 | assign_to_centers
57 |
58 | Classes
59 | =======
60 | **Coordinate classes** encapsulating complex functionality. You don't need to
61 | construct these classes yourself, as this is done by the user API functions above.
62 | Find here a documentation how to extract features from them.
63 |
64 | **I/O and Featurization**
65 |
66 | .. autosummary::
67 | :toctree: generated/
68 |
69 | data.MDFeaturizer
70 | data.CustomFeature
71 |
72 | **Transformation estimators**
73 |
74 | .. autosummary::
75 | :toctree: generated/
76 |
77 | transform.PCA
78 | transform.TICA
79 |
80 | **Clustering algorithms**
81 |
82 | .. autosummary::
83 | :toctree: generated/
84 |
85 | clustering.KmeansClustering
86 | clustering.MiniBatchKmeansClustering
87 | clustering.RegularSpaceClustering
88 | clustering.UniformTimeClustering
89 |
90 | **Transformers**
91 |
92 | .. autosummary::
93 | :toctree: generated/
94 |
95 | transform.transformer.StreamingTransformer
96 | pipelines.Pipeline
97 |
98 | **Discretization**
99 |
100 | .. autosummary::
101 | :toctree: generated/
102 |
103 | clustering.AssignCenters
104 |
105 |
106 | """
from .util import config
from .api import *

# __version__ is computed at import time by the generated _version module
from ._version import get_versions
__version__ = get_versions()['version']
version = __version__
del get_versions

# patch unittest.TestCase / doctest.DocTestFinder so that test-friendly
# settings are applied automatically when tests run (see chainsaw.tests.util)
from chainsaw.tests.util import _monkey_patch_testing_apply_setting
_monkey_patch_testing_apply_setting()
del _monkey_patch_testing_apply_setting
--------------------------------------------------------------------------------
/chainsaw/util/stat.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see .
18 |
19 |
20 |
21 | from __future__ import absolute_import
22 | import numpy as np
23 | from chainsaw.util.annotators import deprecated
24 | from six.moves import zip
25 |
26 | __author__ = 'Fabian Paul'
27 | __all__ = ['histogram']
28 |
29 |
def histogram(transform, dimensions, nbins):
    '''Computes the N-dimensional histogram of the transformed data.

    Parameters
    ----------
    transform : chainsaw.transfrom.Transformer object
        transform that provides the input data
    dimensions : tuple of indices
        indices of the dimensions you want to examine
    nbins : tuple of ints
        number of bins along each dimension

    Returns
    -------
    counts : (bins[0],bins[1],...) ndarray of ints
        counts compatible with pyplot.pcolormesh and pyplot.bar
    edges : list of (bins[i]) ndarrays
        bin edges compatible with pyplot.pcolormesh and pyplot.bar,
        see below.

    Examples
    --------

    >>> import matplotlib.pyplot as plt # doctest: +SKIP

    Only for ipython notebook
    >> %matplotlib inline # doctest: +SKIP

    >>> counts, edges=histogram(transform, dimensions=(0,1), nbins=(20, 30)) # doctest: +SKIP
    >>> plt.pcolormesh(edges[0], edges[1], counts.T) # doctest: +SKIP

    >>> counts, edges=histogram(transform, dimensions=(1,), nbins=(50,)) # doctest: +SKIP
    >>> plt.bar(edges[0][:-1], counts, width=edges[0][1:]-edges[0][:-1]) # doctest: +SKIP
    '''
    n_dims = len(dimensions)
    highest = np.full(n_dims, -np.inf)
    lowest = np.full(n_dims, np.inf)
    # first pass over the data: determine the range along each selected dimension
    for _, chunk in transform:
        selected = chunk[:, dimensions]
        highest = np.maximum(highest, np.max(selected, axis=0))
        lowest = np.minimum(lowest, np.min(selected, axis=0))
    # bin edges spanning the observed range
    edges = [np.linspace(low, high, num=num_bins)
             for low, high, num_bins in zip(lowest, highest, nbins)]
    counts = np.zeros(np.array(nbins) - 1)
    # second pass: accumulate the histogram chunk by chunk
    for _, chunk in transform:
        chunk_counts, _ = np.histogramdd(chunk[:, dimensions], bins=edges)
        counts += chunk_counts
    return counts, edges
--------------------------------------------------------------------------------
/chainsaw/clustering/include/clustering.h:
--------------------------------------------------------------------------------
1 | /* * This file is part of PyEMMA.
2 | *
3 | * Copyright (c) 2015, 2014 Computational Molecular Biology Group
4 | *
5 | * PyEMMA is free software: you can redistribute it and/or modify
6 | * it under the terms of the GNU Lesser General Public License as published by
7 | * the Free Software Foundation, either version 3 of the License, or
8 | * (at your option) any later version.
9 | *
10 | * This program is distributed in the hope that it will be useful,
11 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | * GNU General Public License for more details.
14 | *
15 | * You should have received a copy of the GNU Lesser General Public License
16 | * along with this program. If not, see .
17 | */
18 |
#ifndef _CLUSTERING_H_
#define _CLUSTERING_H_
/* C linkage so these symbols can be referenced from C++ translation units */
#ifdef __cplusplus
extern "C" {
#endif

/* NOTE(review): the targets of the #include directives below appear to have
 * been stripped during extraction (presumably Python.h, numpy/arrayobject.h,
 * omp.h and several C standard headers) -- restore them before compiling. */
#include
/* all translation units of the extension share one numpy C-API symbol table */
#define PY_ARRAY_UNIQUE_SYMBOL pyemma_clustering_ARRAY_API
#include

#ifdef _OPENMP
#include
#endif /* _OPENMP */

#include
#include
#include
#include
#include
#include
#include


#ifndef NULL
#warning null defined manually...
#define NULL 0x0
#endif

/* map SKP_restrict to the compiler's restrict qualifier where available
 * (GCC >= 3.1, MSVC >= 2005), otherwise expand to nothing */
#if defined(__GNUC__) && ((__GNUC__ > 3) || (__GNUC__ == 3 && __GNUC_MINOR__ >= 1))
# define SKP_restrict __restrict
#elif defined(_MSC_VER) && _MSC_VER >= 1400
# define SKP_restrict __restrict
#else
# define SKP_restrict
#endif

/* return codes of c_assign() */
#define ASSIGN_SUCCESS 0
#define ASSIGN_ERR_NO_MEMORY 1
#define ASSIGN_ERR_INVALID_METRIC 2

/* Python-level docstring for the assign() entry point */
static char ASSIGN_USAGE[] = "assign(chunk, centers, dtraj, metric)\n"\
"Assigns frames in `chunk` to the closest cluster centers.\n"\
"\n"\
"Parameters\n"\
"----------\n"\
"chunk : (N,M) C-style contiguous and behaved ndarray of np.float32\n"\
" (input) array of N frames, each frame having dimension M\n"\
"centers : (M,K) ndarray-like of np.float32\n"\
" (input) Non-empty array-like of cluster centers.\n"\
"dtraj : (N) ndarray of np.int64\n"\
" (output) discretized trajectory\n"\
" dtraj[i]=argmin{ d(chunk[i,:],centers[j,:]) | j in 0...(K-1) }\n"\
" where d is the metric that is specified with the argument `metric`.\n"\
"metric : string\n"\
" (input) One of \"euclidean\" or \"minRMSD\" (case sensitive).\n"\
"\n"\
"Returns \n"\
"-------\n"\
"None\n"\
"\n"\
"Note\n"\
"----\n"\
"This function uses the minRMSD implementation of mdtraj.";

// euclidean metric
float euclidean_distance(float *SKP_restrict a, float *SKP_restrict b, size_t n, float *buffer_a, float *buffer_b, float*dummy);
// minRMSD metric
float minRMSD_distance(float *SKP_restrict a, float *SKP_restrict b, size_t n, float *SKP_restrict buffer_a, float *SKP_restrict buffer_b,
float* pre_calc_trace_a);

// assignment to cluster centers from python
PyObject *assign(PyObject *self, PyObject *args);
// assignment to cluster centers from c
int c_assign(float *chunk, float *centers, npy_int32 *dtraj, char* metric,
Py_ssize_t N_frames, Py_ssize_t N_centers, Py_ssize_t dim, int n_threads);

#ifdef __cplusplus
}
#endif
#endif
99 |
--------------------------------------------------------------------------------
/chainsaw/base/estimator.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see .
18 |
19 | from __future__ import absolute_import, print_function
20 |
21 | from chainsaw._ext.sklearn_base import BaseEstimator as _BaseEstimator
22 | from chainsaw.base.loggable import Loggable as _Loggable
23 |
24 | __author__ = 'noe, marscher'
25 |
26 |
class Estimator(_BaseEstimator, _Loggable):
    """ Base class for Chainsaws estimators """
    # flag indicating if estimator's estimate method has been called
    _estimated = False

    def estimate(self, X, **params):
        """ Estimates the model given the data X

        Parameters
        ----------
        X : object
            A reference to the data from which the model will be estimated
        params : dict
            New estimation parameter values. Only parameters announced in this
            estimator's __init__ may be given. They overwrite the values set
            at construction time, i.e. after this call the estimator carries
            the parameters that were actually used for this estimation. Use
            this when only a few parameters differ from the __init__ settings
            and the original values need not be remembered.

        Returns
        -------
        estimator : object
            The estimated estimator with the model being available.

        """
        if params:
            # overwrite construction-time parameters for this estimation run
            self.set_params(**params)
        self._model = self._estimate(X)
        self._estimated = True
        return self

    def _estimate(self, X):
        # subclasses implement the actual estimation here
        raise NotImplementedError(
            'You need to overload the _estimate() method in your Estimator implementation!')

    def fit(self, X):
        """Estimates parameters - for compatibility with sklearn.

        Parameters
        ----------
        X : object
            A reference to the data from which the model will be estimated

        Returns
        -------
        estimator : object
            The estimator (self) with estimated model.

        """
        # sklearn-compatible alias of estimate()
        self.estimate(X)
        return self

    @property
    def model(self):
        """The model estimated by this Estimator"""
        try:
            return self._model
        except AttributeError:
            raise AttributeError(
                'Model has not yet been estimated. Call estimate(X) or fit(X) first')
91 |
--------------------------------------------------------------------------------
/setup_util.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of MSMTools.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # MSMTools is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see .
18 |
19 | """
20 | utility functions for python setup
21 | """
22 | import tempfile
23 | import os
24 | import sys
25 | import shutil
26 | from distutils.ccompiler import new_compiler
27 |
28 | # From http://stackoverflow.com/questions/
29 | # 7018879/disabling-output-when-compiling-with-distutils
def hasfunction(cc, funcname):
    """Check whether compiler `cc` can compile and link a call to `funcname`.

    A tiny C program that calls the function is compiled and linked in a
    temporary directory; returns True on success, False on any failure.
    Compiler error output is suppressed by temporarily redirecting the
    stderr file descriptor to the null device.

    From http://stackoverflow.com/questions/
    7018879/disabling-output-when-compiling-with-distutils
    """
    tmpdir = tempfile.mkdtemp(prefix='hasfunction-')
    devnull = oldstderr = None
    try:
        try:
            fname = os.path.join(tmpdir, 'funcname.c')
            # context manager guarantees the file is closed before compiling
            with open(fname, 'w') as f:
                f.write('int main(void) {\n')
                f.write(' %s();\n' % funcname)
                f.write('}\n')
            # Redirect stderr to the null device to hide any error messages
            # from the compiler. os.devnull is portable, unlike a
            # hard-coded '/dev/null' (which only exists on POSIX).
            devnull = open(os.devnull, 'w')
            oldstderr = os.dup(sys.stderr.fileno())
            os.dup2(devnull.fileno(), sys.stderr.fileno())
            objects = cc.compile([fname], output_dir=tmpdir)
            cc.link_executable(objects, os.path.join(tmpdir, "a.out"))
        except Exception:
            # a bare "except:" would also swallow KeyboardInterrupt/SystemExit
            return False
        return True
    finally:
        if oldstderr is not None:
            os.dup2(oldstderr, sys.stderr.fileno())
            # close the duplicated descriptor, otherwise it leaks
            os.close(oldstderr)
        if devnull is not None:
            devnull.close()
        shutil.rmtree(tmpdir)
59 |
60 |
def detect_openmp():
    """Probe whether the default C compiler can link OpenMP code.

    Returns the pair (hasopenmp, needs_gomp).
    """
    compiler = new_compiler()
    # first try to link omp_get_num_threads without any extra library
    has_openmp = hasfunction(compiler, 'omp_get_num_threads')
    needs_gomp = has_openmp
    if not has_openmp:
        # retry with libgomp, which provides the OpenMP runtime for gcc
        compiler.add_library('gomp')
        has_openmp = hasfunction(compiler, 'omp_get_num_threads')
        needs_gomp = has_openmp
    # NOTE(review): needs_gomp mirrors has_openmp on both paths, so it is True
    # even when linking succeeded without libgomp -- confirm this is intended.
    return has_openmp, needs_gomp
70 |
71 |
def getSetuptoolsError():
    """Build a human-readable message explaining how to bootstrap setuptools."""
    bootstrap_setuptools = """\
python2.7 -c "import urllib2;
url=\'https://bootstrap.pypa.io/ez_setup.py\';\n
exec urllib2.urlopen(url).read()\""""
    # frame the bootstrap command with '=' separator lines
    separator = 80 * '='
    cmd = '\n'.join((separator, bootstrap_setuptools, separator))
    return 'You can use the following command to upgrade/install it:\n%s' % cmd
80 |
81 |
class lazy_cythonize(list):
    """evaluates extension list lazyly.
    pattern taken from http://tinyurl.com/qb8478q"""

    def __init__(self, callback):
        # the real extension list is produced by the callback on first access
        self._list = None
        self.callback = callback

    def c_list(self):
        # memoize the callback result so it runs at most once
        if self._list is None:
            self._list = self.callback()
        return self._list

    def __iter__(self):
        return iter(self.c_list())

    def __getitem__(self, ii):
        return self.c_list()[ii]

    def __len__(self):
        return len(self.c_list())
--------------------------------------------------------------------------------
/chainsaw/base/loggable.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see .
18 |
19 | '''
20 | Created on 30.08.2015
21 |
22 | @author: marscher
23 | '''
24 | from __future__ import absolute_import
25 | import logging
26 | import weakref
27 | from itertools import count
28 |
29 | __all__ = ['Loggable']
30 |
31 |
class Loggable(object):
    """Mixin providing a lazily-created, per-instance logger.

    The logger is named after the instance ("module.Class[i]") and is removed
    from the logging manager again once the instance is garbage collected.
    """
    # counting instances of Loggable, incremented by name property.
    __ids = count(0)
    # holds weak references to instances of this, to clean up logger instances.
    __refs = {}

    _loglevel_DEBUG = logging.DEBUG
    _loglevel_INFO = logging.INFO
    _loglevel_WARN = logging.WARN
    _loglevel_ERROR = logging.ERROR
    _loglevel_CRITICAL = logging.CRITICAL

    @property
    def name(self):
        """The name of this instance"""
        try:
            return self._name
        except AttributeError:
            self._name = "%s.%s[%i]" % (self.__module__,
                                        self.__class__.__name__,
                                        next(Loggable.__ids))
            return self._name

    @property
    def logger(self):
        """The logger for this class instance """
        try:
            return self._logger_instance
        except AttributeError:
            self.__create_logger()
            return self._logger_instance

    @property
    def _logger(self):
        # alias kept for backwards compatibility
        return self.logger

    def _logger_is_active(self, level):
        """ @param level: int log level (debug=10, info=20, warn=30, error=40, critical=50)"""
        # The previous check "self.logger.level >= level" was inverted and
        # ignored level inheritance: a fresh logger has level NOTSET (== 0),
        # making the comparison False for every real level. isEnabledFor()
        # implements the correct semantics including ancestor levels.
        return self.logger.isEnabledFor(level)

    def __create_logger(self):
        _weak_logger_refs = Loggable.__refs
        # creates a logger based on the attribute "name" of self
        self._logger_instance = logging.getLogger(self.name)

        # store a weakref to this instance to clean the logger instance.
        logger_id = id(self._logger_instance)
        r = weakref.ref(self, Loggable._cleanup_logger(logger_id, self.name))
        _weak_logger_refs[logger_id] = r

    @staticmethod
    def _cleanup_logger(logger_id, logger_name):
        # callback function used in conjunction with weakref.ref;
        # removes the logger from the root manager when the owning
        # instance is garbage collected
        def remove_logger(weak):
            d = logging.getLogger().manager.loggerDict
            del d[logger_name]
            del Loggable.__refs[logger_id]
        return remove_logger

    def __getstate__(self):
        # do not pickle the logger instance; it is re-created lazily on demand
        d = dict(self.__dict__)
        try:
            del d['_logger_instance']
        except KeyError:
            pass
        return d
101 |
--------------------------------------------------------------------------------
/devtools/ci/appveyor/transform_xunit_to_appveyor.xsl:
--------------------------------------------------------------------------------
1 |
12 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 | Fail
69 | Skip
70 | Pass
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
--------------------------------------------------------------------------------
/chainsaw/clustering/assign.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see .
18 |
19 | '''
20 | Created on 18.02.2015
21 |
22 | @author: marscher
23 | '''
24 |
25 | from __future__ import absolute_import
26 |
27 | import numpy as np
28 | import six
29 |
30 | from .interface import AbstractClustering
31 | from chainsaw.util.annotators import fix_docs
32 |
33 |
@fix_docs
class AssignCenters(AbstractClustering):

    """Assigns given (pre-calculated) cluster centers. If you already have
    cluster centers from somewhere, you use this class to assign your data to it.

    Parameters
    ----------
    clustercenters : path to file (csv) or npyfile or ndarray
        cluster centers to use in assignment of data
    metric : str
        metric to use during clustering ('euclidean', 'minRMSD')
    stride : int
        stride
    n_jobs : int or None, default None
        Number of threads to use during assignment of the data.
        If None, all available CPUs will be used.
    skip : int, default=0
        skip the first initial n frames per trajectory.

    Examples
    --------
    Assuming you have stored your centers in a CSV file:

    >>> from chainsaw.clustering import AssignCenters
    >>> from chainsaw import pipeline
    >>> reader = ... # doctest: +SKIP
    >>> assign = AssignCenters('my_centers.dat') # doctest: +SKIP
    >>> disc = pipeline(reader, cluster=assign) # doctest: +SKIP
    >>> disc.parametrize() # doctest: +SKIP

    """

    def __init__(self, clustercenters, metric='euclidean', stride=1, n_jobs=None, skip=0):
        super(AssignCenters, self).__init__(metric=metric, n_jobs=n_jobs)

        # a string argument is interpreted as a file path and read via a file
        # reader; anything else is coerced to a C-ordered float32 array.
        if isinstance(clustercenters, six.string_types):
            from chainsaw.data import create_file_reader
            reader = create_file_reader(clustercenters, None, None)
            clustercenters = reader.get_output()[0]
        else:
            clustercenters = np.array(clustercenters, dtype=np.float32, order='C')

        # sanity check.
        if not clustercenters.ndim == 2:
            raise ValueError('cluster centers have to be 2d')

        self.set_params(clustercenters=clustercenters, metric=metric, stride=stride, skip=skip)

        # since we provided centers, no estimation is required.
        self._estimated = True

    def describe(self):
        # short, human-readable summary (used in logs/repr)
        return "[{name} centers shape={shape}]".format(name=self.name, shape=self.clustercenters.shape)

    @AbstractClustering.data_producer.setter
    def data_producer(self, dp):
        # check dimensions: incoming data must match the dimension of the centers
        dim = self.clustercenters.shape[1]
        if not dim == dp.dimension():
            raise ValueError('cluster centers have wrong dimension. Have dim=%i'
                             ', but input has %i' % (dim, dp.dimension()))
        # delegate to the base-class property setter
        AbstractClustering.data_producer.fset(self, dp)

    def _estimate(self, iterable, **kw):
        # temporarily swap in the given iterable as data source, run the
        # assignment, then restore the previous source in any case.
        old_source = self._data_producer
        self.data_producer = iterable
        try:
            self.assign(None, self.stride)
        finally:
            self.data_producer = old_source

        return self
106 |
--------------------------------------------------------------------------------
/chainsaw/tests/test_api_load.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see .
18 |
19 |
20 | '''
21 | Created on 14.04.2015
22 |
23 | @author: marscher
24 | '''
25 |
26 | from __future__ import absolute_import
27 | # unicode compat py2/3
28 | from six import text_type
29 | import unittest
30 | from chainsaw.api import load
31 | import os
32 |
33 | import numpy as np
34 | from chainsaw import api
35 |
36 | import pkg_resources
path = pkg_resources.resource_filename(__name__, 'data') + os.path.sep

# topology plus the two non-contiguous chunks of the BPTI mini trajectory
pdb_file = os.path.join(path, 'bpti_ca.pdb')
traj_files = [os.path.join(path, fname)
              for fname in ('bpti_001-033.xtc', 'bpti_067-100.xtc')]
44 |
45 |
class TestAPILoad(unittest.TestCase):
    """Tests for the top-level load() API: input types, formats, failure modes."""

    @classmethod
    def setUpClass(cls):
        data_dir = pkg_resources.resource_filename(__name__, 'data') + os.path.sep
        cls.bpti_pdbfile = os.path.join(data_dir, 'bpti_ca.pdb')
        extensions = ['.xtc', '.binpos', '.dcd', '.h5', '.lh5', '.nc', '.netcdf', '.trr']
        cls.bpti_mini_files = [os.path.join(data_dir, 'bpti_mini%s' % ext) for ext in extensions]

    def testUnicodeString_without_featurizer(self):
        # a raw trajectory without featurizer/topology cannot be loaded
        with self.assertRaises(ValueError):
            load(text_type(traj_files[0]))

    def testUnicodeString(self):
        # unicode filenames are accepted when a featurizer is supplied
        load(text_type(traj_files[0]), api.featurizer(pdb_file))

    def test_various_formats_load(self):
        # every supported format must yield the same coordinates, for every
        # chunk size; each result is compared against the previous one
        previous_output = None
        previous_name = None
        for cs in (0, 13):
            for fname in self.bpti_mini_files:
                current = api.load(fname, top=self.bpti_pdbfile, chunk_size=cs)
                if previous_output is not None:
                    np.testing.assert_array_almost_equal(
                        previous_output, current,
                        err_msg='Comparing %s to %s failed for chunksize %s'
                                % (fname, previous_name, cs))
                previous_output = current
                previous_name = fname

    def test_load_traj(self):
        # a single file yields a single ndarray
        result = load(traj_files[0], api.featurizer(pdb_file))
        self.assertEqual(type(result), np.ndarray)

    def test_load_trajs(self):
        # a list of files yields a list of ndarrays
        result = load(traj_files, api.featurizer(pdb_file))
        self.assertEqual(type(result), list)
        self.assertTrue(all(type(x) == np.ndarray for x in result))

    def test_with_trajs_without_featurizer_or_top(self):
        # without featurizer or topology loading must fail ...
        with self.assertRaises(ValueError):
            load(traj_files)

        # ... but a topology alone is sufficient
        output = load(traj_files, top=pdb_file)
        self.assertEqual(type(output), list)
        self.assertEqual(len(output), len(traj_files))

    def test_non_existant_input(self):
        missing = [traj_files[0], "does_not_exist_for_sure"]

        with self.assertRaises(ValueError):
            load(trajfiles=missing, top=pdb_file)

    def test_empty_list(self):
        with self.assertRaises(ValueError):
            load([], top=pdb_file)
112 |
# Allow running this module's tests directly from the command line.
if __name__ == "__main__":
    unittest.main()
--------------------------------------------------------------------------------
/chainsaw/util/log.py:
--------------------------------------------------------------------------------
1 | # This file is part of PyEMMA.
2 | #
3 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
4 | #
5 | # PyEMMA is free software: you can redistribute it and/or modify
6 | # it under the terms of the GNU Lesser General Public License as published by
7 | # the Free Software Foundation, either version 3 of the License, or
8 | # (at your option) any later version.
9 | #
10 | # This program is distributed in the hope that it will be useful,
11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | # GNU General Public License for more details.
14 | #
15 | # You should have received a copy of the GNU Lesser General Public License
16 | # along with this program. If not, see .
17 | '''
18 | Created on 15.10.2013
19 |
20 | @author: marscher
21 | '''
22 |
23 | from __future__ import absolute_import
24 |
25 | import logging
26 | from logging.config import dictConfig
27 | import os.path
28 | import warnings
29 |
30 |
# Public API of this module.
# NOTE(review): 'getLogger' is not defined in this module (only `import
# logging` is present) -- confirm whether a `from logging import getLogger`
# is missing or the export is stale.
__all__ = ['getLogger',
           ]
33 |
34 |
class LoggingConfigurationError(RuntimeError):
    """Raised when neither the configured nor the default logging
    configuration could be read or applied."""
    pass
37 |
38 |
def setup_logging(config, D=None):
    """Set up the logging system with the configured (in chainsaw.cfg)
    logging config (logging.yml).

    Parameters
    ----------
    config : chainsaw.config module (wrapper)
        provides ``logging_config``, ``default_logging_file`` and ``cfg_dir``.
    D : dict, optional
        an already parsed logging configuration; when given, no file is read.

    Raises
    ------
    LoggingConfigurationError
        if no usable configuration could be read, or the configuration is empty.
    """
    if not D:
        import yaml

        args = config.logging_config
        default = False

        if args.upper() == 'DEFAULT':
            default = True
            src = config.default_logging_file
        else:
            src = args

        # first try to read configured file
        try:
            with open(src) as f:
                # safe_load: a logging config must never construct arbitrary
                # Python objects (yaml.load without a Loader is unsafe and
                # deprecated).
                D = yaml.safe_load(f)
        except EnvironmentError as ee:
            # fall back to default
            if not default:
                try:
                    with open(config.default_logging_file) as f2:
                        D = yaml.safe_load(f2)
                except EnvironmentError as ee2:
                    raise LoggingConfigurationError('Could not read either configured nor '
                                                    'default logging configuration!\n%s' % ee2)
            else:
                raise LoggingConfigurationError('could not handle default logging '
                                                'configuration file\n%s' % ee)

    if D is None:
        raise LoggingConfigurationError('Empty logging config! Try using default config by'
                                        ' setting logging_conf=DEFAULT in chainsaw.cfg')

    # if the user has not explicitly disabled other loggers, we (contrary to Pythons
    # default value) do not want to override them.
    D.setdefault('disable_existing_loggers', False)

    # configure using the dict
    try:
        dictConfig(D)
    except ValueError as ve:
        # the rotating file handler could not open its log file; retry with a
        # log file inside the (writable) configuration directory.
        if 'files' in str(ve) and 'rotating_files' in D['handlers']:
            new_file = os.path.join(config.cfg_dir, 'chainsaw.log')
            warnings.warn("set logfile to %s, because there was"
                          " an error writing to the desired one" % new_file)
            D['handlers']['rotating_files']['filename'] = new_file
            dictConfig(D)
        else:
            raise

    # remember the log files of the chainsaw root logger, so that empty ones
    # can be removed at interpreter exit.
    logger = logging.getLogger('chainsaw')
    log_files = [getattr(h, 'baseFilename', None) for h in logger.handlers]

    import atexit
    @atexit.register
    def clean_empty_log_files():
        # gracefully shutdown logging system
        logging.shutdown()
        for f in log_files:
            if f is not None and os.path.exists(f):
                try:
                    if os.stat(f).st_size == 0:
                        os.remove(f)
                except OSError as o:
                    print("during removal of empty logfiles there was a problem: ", o)
112 |
113 |
--------------------------------------------------------------------------------
/chainsaw/clustering/uniform_time.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see .
18 |
19 |
20 | from __future__ import absolute_import, division
21 |
22 | import math
23 |
24 | import numpy as np
25 |
26 | from .interface import AbstractClustering
27 | from chainsaw.util.annotators import fix_docs
28 |
29 | __author__ = 'noe'
30 | __all__ = ['UniformTimeClustering']
31 |
32 |
@fix_docs
class UniformTimeClustering(AbstractClustering):
    r"""Uniform time clustering

    Picks cluster centers uniformly spaced in time over the concatenation of
    all input trajectories.
    """

    def __init__(self, n_clusters=2, metric='euclidean', stride=1, n_jobs=None, skip=0):
        r"""
        Uniform time clustering

        Parameters
        ----------
        n_clusters : int
            amount of desired cluster centers. When not specified (None),
            min(sqrt(N), 5000) is chosen as default value,
            where N denotes the number of data points
        metric : str
            metric to use during clustering ('euclidean', 'minRMSD')
        stride : int
            stride
        n_jobs : int or None, default None
            Number of threads to use during assignment of the data.
            If None, all available CPUs will be used.
        skip : int, default=0
            skip the first initial n frames per trajectory.
        """
        super(UniformTimeClustering, self).__init__(metric=metric, n_jobs=n_jobs)
        self.set_params(n_clusters=n_clusters, metric=metric, stride=stride, skip=skip)

    def describe(self):
        """Return a short human-readable summary of this clustering."""
        return "[Uniform time clustering, k = %i, inp_dim=%i]" \
               % (self.n_clusters, self.data_producer.dimension())

    def _estimate(self, iterable, **kw):
        """Pick self.n_clusters frames, uniformly spaced in time, as centers."""
        if self.n_clusters is None:
            traj_lengths = self.trajectory_lengths(stride=self.stride, skip=self.skip)
            total_length = sum(traj_lengths)
            self.n_clusters = min(int(math.sqrt(total_length)), 5000)
            self._logger.info("The number of cluster centers was not specified, "
                              "using min(sqrt(N), 5000)=%s as n_clusters." % self.n_clusters)

        # initialize time counters
        T = iterable.n_frames_total(stride=self.stride, skip=self.skip)
        if self.n_clusters > T:
            # log *before* truncating n_clusters, so the message shows the
            # requested k (the previous code logged after the assignment and
            # therefore printed T three times).
            self._logger.info('Requested more clusters (k = %i'
                              ' than there are total data points %i)'
                              '. Will do clustering with k = %i'
                              % (self.n_clusters, T, T))
            self.n_clusters = T

        # first data point in the middle of the time segment
        next_t = (T // self.n_clusters) // 2
        # cumulative sum of trajectory lengths
        cumsum = np.cumsum(self.trajectory_lengths(skip=self.skip))
        # distribution of integers, truncate if n_clusters is too large
        linspace = self.stride * np.arange(next_t, T - next_t + 1,
                                           (T - 2 * next_t + 1) // self.n_clusters)[:self.n_clusters]
        # random access matrix: one (traj_index, frame_index) row per chosen frame
        ra_stride = np.array([UniformTimeClustering._idx_to_traj_idx(x, cumsum) for x in linspace])
        with iterable.iterator(stride=ra_stride, return_trajindex=False,
                               chunk=self.chunksize, skip=self.skip) as it:
            self.clustercenters = np.concatenate([X for X in it])

        assert len(self.clustercenters) == self.n_clusters
        return self

    @staticmethod
    def _idx_to_traj_idx(idx, cumsum):
        """Map a global frame index to a (trajectory index, local frame) pair."""
        prev_len = 0
        for trajIdx, length in enumerate(cumsum):
            if prev_len <= idx < length:
                return trajIdx, idx - prev_len
            prev_len = length
        raise ValueError("Requested index %s was out of bounds [0,%s)" % (idx, cumsum[-1]))
104 |
--------------------------------------------------------------------------------
/chainsaw/tests/data/bpti_ca.pdb:
--------------------------------------------------------------------------------
1 | ATOM 2 CA ARG A 1 5.137 10.135 0.877 0.00 0.00 C1 C
2 | ATOM 28 CA PRO A 2 6.740 7.345 -1.220 0.00 0.00 C1 C
3 | ATOM 42 CA ASP A 3 4.355 4.870 -2.726 0.00 0.00 C1 C
4 | ATOM 54 CA PHE A 4 5.732 1.868 -0.668 0.00 0.00 C1 C
5 | ATOM 74 CA CYS A 5 5.395 3.903 2.491 0.00 0.00 C1 C
6 | ATOM 84 CA LEU A 6 1.634 4.834 1.750 0.00 0.00 C1 C
7 | ATOM 103 CA GLU A 7 0.330 1.215 1.642 0.00 0.00 C1 C
8 | ATOM 118 CA PRO A 8 -1.381 -1.343 4.103 0.00 0.00 C1 C
9 | ATOM 132 CA PRO A 9 0.785 -4.139 5.864 0.00 0.00 C1 C
10 | ATOM 146 CA TYR A 10 0.338 -7.965 5.239 0.00 0.00 C1 C
11 | ATOM 167 CA THR A 11 1.073 -10.832 7.801 0.00 0.00 C1 C
12 | ATOM 181 CA GLY A 12 1.165 -13.584 4.948 0.00 0.00 C1 C
13 | ATOM 188 CA PRO A 13 0.230 -17.400 5.025 0.00 0.00 C1 C
14 | ATOM 202 CA CYS A 14 3.079 -18.621 7.107 0.00 0.00 C1 C
15 | ATOM 212 CA LYS A 15 2.582 -19.937 10.866 0.00 0.00 C1 C
16 | ATOM 234 CA ALA A 16 4.792 -17.419 12.911 0.00 0.00 C1 C
17 | ATOM 244 CA ARG A 17 3.794 -14.113 14.750 0.00 0.00 C1 C
18 | ATOM 268 CA ILE A 18 7.191 -12.307 14.496 0.00 0.00 C1 C
19 | ATOM 287 CA ILE A 19 7.143 -8.504 14.984 0.00 0.00 C1 C
20 | ATOM 306 CA ARG A 20 8.006 -6.620 11.818 0.00 0.00 C1 C
21 | ATOM 330 CA TYR A 21 7.877 -2.855 11.290 0.00 0.00 C1 C
22 | ATOM 351 CA PHE A 22 5.979 -0.736 8.661 0.00 0.00 C1 C
23 | ATOM 371 CA TYR A 23 4.972 2.928 7.941 0.00 0.00 C1 C
24 | ATOM 392 CA ASN A 24 1.390 3.611 9.146 0.00 0.00 C1 C
25 | ATOM 406 CA ALA A 25 0.700 6.390 6.695 0.00 0.00 C1 C
26 | ATOM 416 CA LYS A 26 -2.664 7.139 8.539 0.00 0.00 C1 C
27 | ATOM 438 CA ALA A 27 -0.686 8.124 11.724 0.00 0.00 C1 C
28 | ATOM 448 CA GLY A 28 2.403 9.458 9.885 0.00 0.00 C1 C
29 | ATOM 455 CA LEU A 29 4.794 7.134 11.896 0.00 0.00 C1 C
30 | ATOM 474 CA CYS A 30 6.628 3.768 11.766 0.00 0.00 C1 C
31 | ATOM 484 CA GLN A 31 4.809 1.106 13.924 0.00 0.00 C1 C
32 | ATOM 501 CA THR A 32 4.558 -2.660 14.660 0.00 0.00 C1 C
33 | ATOM 515 CA PHE A 33 2.741 -5.558 13.036 0.00 0.00 C1 C
34 | ATOM 535 CA VAL A 34 2.912 -9.385 13.166 0.00 0.00 C1 C
35 | ATOM 551 CA TYR A 35 4.160 -11.476 10.334 0.00 0.00 C1 C
36 | ATOM 572 CA GLY A 36 4.181 -15.187 9.310 0.00 0.00 C1 C
37 | ATOM 579 CA GLY A 37 7.863 -15.378 8.107 0.00 0.00 C1 C
38 | ATOM 586 CA CYS A 38 7.170 -15.482 4.311 0.00 0.00 C1 C
39 | ATOM 596 CA ARG A 39 6.097 -13.526 1.228 0.00 0.00 C1 C
40 | ATOM 620 CA ALA A 40 6.717 -9.865 2.308 0.00 0.00 C1 C
41 | ATOM 630 CA LYS A 41 4.813 -6.980 0.893 0.00 0.00 C1 C
42 | ATOM 652 CA ARG A 42 6.897 -3.891 0.135 0.00 0.00 C1 C
43 | ATOM 676 CA ASN A 43 5.975 -2.121 3.481 0.00 0.00 C1 C
44 | ATOM 690 CA ASN A 44 7.760 -4.726 5.781 0.00 0.00 C1 C
45 | ATOM 704 CA PHE A 45 10.990 -3.890 7.611 0.00 0.00 C1 C
46 | ATOM 724 CA LYS A 46 13.280 -5.405 10.215 0.00 0.00 C1 C
47 | ATOM 746 CA SER A 47 13.349 -2.293 12.285 0.00 0.00 C1 C
48 | ATOM 757 CA ALA A 48 11.836 1.205 12.902 0.00 0.00 C1 C
49 | ATOM 767 CA GLU A 49 15.549 2.323 12.263 0.00 0.00 C1 C
50 | ATOM 782 CA ASP A 50 15.332 0.878 8.725 0.00 0.00 C1 C
51 | ATOM 794 CA CYS A 51 11.618 2.060 8.308 0.00 0.00 C1 C
52 | ATOM 804 CA MET A 52 12.708 5.603 9.261 0.00 0.00 C1 C
53 | ATOM 821 CA ARG A 53 15.884 5.306 7.040 0.00 0.00 C1 C
54 | ATOM 845 CA THR A 54 13.448 4.173 4.196 0.00 0.00 C1 C
55 | ATOM 859 CA CYS A 55 10.100 6.043 4.700 0.00 0.00 C1 C
56 | ATOM 869 CA GLY A 56 11.076 9.290 6.702 0.00 0.00 C1 C
57 | ATOM 876 CA GLY A 57 10.474 11.834 3.826 0.00 0.00 C1 C
58 | ATOM 883 CA ALA A 58 7.286 13.898 2.917 0.00 0.00 C1 C
59 |
--------------------------------------------------------------------------------
/chainsaw/base/model.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see .
18 |
19 | from __future__ import absolute_import
20 | import warnings
21 |
22 | from chainsaw.util.reflection import getargspec_no_self
23 |
24 | __author__ = 'noe'
25 |
26 |
class Model(object):
    """ Base class for Chainsaws models

    Inspired by sklearn's BaseEstimator. Parameter names are not derived from
    the current class' __init__ but are announced through the signature of
    ``set_model_params``, which also allows remembering the parameters of
    model superclasses. Can be mixed with Chainsaw and sklearn Estimators.

    """

    def _get_model_param_names(self):
        r"""Get parameter names for the model"""
        set_params = getattr(self, 'set_model_params', None)
        if set_params is None:
            # no parameters announced
            return []
        # introspect set_model_params' signature to find the parameters to
        # represent
        args, varargs, kw, default = getargspec_no_self(set_params)
        if varargs is not None:
            raise RuntimeError("Models should always specify their parameters in the signature"
                               " of their set_model_params (no varargs). %s doesn't follow this convention."
                               % (self, ))
        # drop 'self' and return the names sorted
        # XXX: this fails if set_model_params is a staticmethod, but
        # who would do this?
        return sorted(args[1:])

    def update_model_params(self, **params):
        r"""Update given model parameters if they are set to specific values"""
        for name, new_value in list(params.items()):
            # assign when the attribute is missing, currently None, or a real
            # (non-None) value was supplied; None never overwrites a value.
            if (not hasattr(self, name)
                    or getattr(self, name) is None
                    or new_value is not None):
                setattr(self, name, new_value)

    def get_model_params(self, deep=True):
        r"""Get parameters for this model.

        Parameters
        ----------
        deep: boolean, optional
            If True, will return the parameters for this estimator and
            contained subobjects that are estimators.
        Returns
        -------
        params : mapping of string to any
            Parameter names mapped to their values.
        """
        out = {}
        for name in self._get_model_param_names():
            # We need deprecation warnings to always be on in order to
            # catch deprecated param values; the filter installed here is
            # popped again in the finally clause.
            warnings.simplefilter("always", DeprecationWarning)
            try:
                with warnings.catch_warnings(record=True) as caught:
                    value = getattr(self, name, None)
                    if len(caught) and caught[0].category == DeprecationWarning:
                        # deprecated parameter: do not report it
                        continue
            finally:
                warnings.filters.pop(0)

            # XXX: should we rather test if instance of estimator?
            if deep and hasattr(value, 'get_params'):
                out.update((name + '__' + k, v)
                           for k, v in value.get_params().items())
            out[name] = value
        return out
103 |
--------------------------------------------------------------------------------
/chainsaw/tests/test_cache.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 | import unittest
4 | from glob import glob
5 |
6 | import numpy as np
7 | from chainsaw import config
8 |
9 | from chainsaw.data.cache import Cache
10 | import chainsaw
11 |
12 | from logging import getLogger
13 |
class ProfilerCase(unittest.TestCase):
    """TestCase base that profiles every test method.

    Profiling starts in setUp; the collected statistics are dumped to a file
    named after the test method in tearDown.
    """

    def setUp(self):
        # start profiling before every test method
        from cProfile import Profile
        self.pr = Profile()
        self.pr.enable()

    def tearDown(self):
        # write the collected stats to a file named after the test
        self.pr.dump_stats(self._testMethodName)
22 |
23 |
class TestCache(unittest.TestCase):
    """Tests for chainsaw.data.cache.Cache wrapping various data sources."""

    @classmethod
    def setUpClass(cls):
        # disable the trajectory-lengths cache so all tests start from a
        # clean state
        config.use_trajectory_lengths_cache = False
        cls.test_dir = tempfile.mkdtemp(prefix="test_cache_")

        # ten random trajectories of 1000 frames x 3 dims, saved as .npy
        cls.length = 1000
        cls.dim = 3
        data = [np.random.random((cls.length, cls.dim)) for _ in range(10)]

        for i, x in enumerate(data):
            np.save(os.path.join(cls.test_dir, "{}.npy".format(i)), x)

        cls.files = glob(cls.test_dir + "/*.npy")

    def setUp(self):
        super(TestCache, self).setUp()
        # every test gets its own cache directory below the shared test dir
        self.tmp_cache_dir = tempfile.mkdtemp(dir=self.test_dir)
        config.cache_dir = self.tmp_cache_dir
        self.logger = getLogger("chainsaw.test.%s"%self.id())

    @classmethod
    def tearDownClass(cls):
        import shutil
        # removes the data files and all per-test cache dirs in one go
        shutil.rmtree(cls.test_dir, ignore_errors=False)

    def test_cache_hits(self):
        # a first pass over a fresh cache must miss once per input file,
        # a second pass must hit once per file
        src = chainsaw.source(self.files, chunk_size=1000)
        src.describe()
        data = src.get_output()
        cache = Cache(src)
        self.assertEqual(cache.data.misses, 0)
        out = cache.get_output()
        self.assertEqual(cache.data.misses, len(self.files))

        # cached output must match the uncached source output
        for actual, desired in zip(out, data):
            np.testing.assert_allclose(actual, desired, atol=1e-15, rtol=1e-7)

        cache.get_output()
        self.assertEqual(len(cache.data.hits), len(self.files))

        #self.assertIn("items={}".format(len(cache)), repr(cache))

    def test_get_output(self):
        # the cache must honour stride/dimension/skip selections
        src = chainsaw.source(self.files, chunk_size=0)
        dim = 1
        stride = 2
        skip = 3
        desired = src.get_output(stride=stride, dimensions=dim, skip=skip)

        cache = Cache(src)
        actual = cache.get_output(stride=stride, dimensions=dim, skip=skip)
        np.testing.assert_allclose(actual, desired)

    def test_tica_cached_input(self):
        # TICA fed from a cache must match TICA fed from the raw source
        src = chainsaw.source(self.files, chunk_size=0)
        cache = Cache(src)

        tica_cache_inp = chainsaw.tica(cache)

        tica_without_cache = chainsaw.tica(src)

        np.testing.assert_allclose(tica_cache_inp.cov, tica_without_cache.cov, atol=1e-10)
        np.testing.assert_allclose(tica_cache_inp.cov_tau, tica_without_cache.cov_tau, atol=1e-9)

        # eigenvectors are only defined up to a sign, hence the abs comparison
        np.testing.assert_allclose(tica_cache_inp.eigenvalues, tica_without_cache.eigenvalues, atol=1e-7)
        np.testing.assert_allclose(np.abs(tica_cache_inp.eigenvectors),
                                   np.abs(tica_without_cache.eigenvectors), atol=1e-6)

    def test_tica_cached_output(self):
        # caching a transformer must reproduce its direct output
        src = chainsaw.source(self.files, chunk_size=0)
        tica = chainsaw.tica(src, dim=2)

        tica_output = tica.get_output()
        cache = Cache(tica)

        np.testing.assert_allclose(cache.get_output(), tica_output)

    def test_cache_switch_cache_file(self):
        # NOTE(review): this test constructs a cache but asserts nothing --
        # presumably an unfinished stub; consider adding assertions.
        src = chainsaw.source(self.files, chunk_size=0)
        t = chainsaw.tica(src, dim=2)
        cache = Cache(t)

    def test_with_feature_reader_switch_cache_file(self):
        # changing the featurization should switch to a different cache file;
        # reverting the change should yield the previous cache file name again
        import pkg_resources
        path = pkg_resources.resource_filename(__name__, 'data') + os.path.sep
        pdbfile = os.path.join(path, 'bpti_ca.pdb')
        trajfiles = os.path.join(path, 'bpti_mini.xtc')

        reader = chainsaw.source(trajfiles, top=pdbfile)
        reader.featurizer.add_selection([0,1, 2])

        cache = Cache(reader)
        name_of_cache = cache.current_cache_file_name

        reader.featurizer.add_selection([5, 8, 9])
        new_name_of_cache = cache.current_cache_file_name

        self.assertNotEqual(name_of_cache, new_name_of_cache)

        # remove 2nd feature and check we've got the old name back.
        reader.featurizer.active_features.pop()
        self.assertEqual(cache.current_cache_file_name, name_of_cache)

        # add a new file and ensure we still got the same cache file
        reader.filenames.append(os.path.join(path, 'bpti_001-033.xtc'))
        self.assertEqual(cache.current_cache_file_name, name_of_cache)
132 |
133 |
# Allow running this module's tests directly from the command line.
if __name__ == '__main__':
    unittest.main()
136 |
--------------------------------------------------------------------------------
/chainsaw/tests/test_featurereader_and_tica.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see .
18 |
19 |
20 | '''
21 | Test feature reader and Tica with a set of cosine time series.
22 | @author: Fabian Paul
23 | '''
24 |
25 | from __future__ import absolute_import
26 | from __future__ import print_function
27 |
28 | import os
29 | import tempfile
30 | import unittest
31 | from logging import getLogger
32 |
33 | import mdtraj
34 | import numpy as np
35 | from chainsaw import api
36 | from chainsaw.data.md.feature_reader import FeatureReader
37 | from six.moves import range
38 |
39 | log = getLogger('chainsaw.'+'TestFeatureReaderAndTICA')
40 |
41 |
class TestFeatureReaderAndTICA(unittest.TestCase):
    """Feature reader + TICA on synthetic cosine time series whose
    covariances are known analytically."""

    @classmethod
    def setUpClass(cls):
        cls.dim = 9 # dimension (must be divisible by 3)
        N = 50000 # length of single trajectory # 500000
        N_trajs = 10 # number of trajectories

        cls.w = 2.0*np.pi*1000.0/N # have 1000 cycles in each trajectory

        # get random amplitudes and phases
        cls.A = np.random.randn(cls.dim)
        cls.phi = np.random.random_sample((cls.dim,))*np.pi*2.0
        mean = np.random.randn(cls.dim)

        # create topology file -- mkstemp instead of the deprecated and
        # race-prone tempfile.mktemp; the open descriptor is closed right
        # away since we re-open the path below.
        fd, cls.temppdb = tempfile.mkstemp('.pdb')
        os.close(fd)
        with open(cls.temppdb, 'w') as f:
            for i in range(cls.dim//3):
                print(('ATOM %5d C ACE A 1 28.490 31.600 33.379 0.00 1.00' % i), file=f)

        t = np.arange(0, N)
        t_total = 0
        cls.trajnames = [] # list of xtc file names
        for i in range(N_trajs):
            # set up data: cosines continued across trajectory boundaries
            data = cls.A*np.cos((cls.w*(t+t_total))[:, np.newaxis]+cls.phi) + mean
            xyz = data.reshape((N, cls.dim//3, 3))
            # create trajectory file
            traj = mdtraj.load(cls.temppdb)
            traj.xyz = xyz
            traj.time = t
            fd, tempfname = tempfile.mkstemp('.xtc')
            os.close(fd)
            traj.save(tempfname)
            cls.trajnames.append(tempfname)
            t_total += N

    @classmethod
    def tearDownClass(cls):
        # remove all generated trajectory and topology files
        for fname in cls.trajnames:
            os.unlink(fname)
        os.unlink(cls.temppdb)
        super(TestFeatureReaderAndTICA, cls).tearDownClass()

    def test_covariances_and_eigenvalues(self):
        """TICA covariances must match the analytical solution for cosines."""
        reader = FeatureReader(self.trajnames, self.temppdb, chunksize=10000)
        for lag in [1, 11, 101, 1001, 2001]: # avoid cos(w*tau)==0
            trans = api.tica(data=reader, dim=self.dim, lag=lag)
            log.info('number of trajectories reported by tica %d' % trans.number_of_trajectories())
            log.info('tau = %d corresponds to a number of %f cycles' % (lag, self.w*lag/(2.0*np.pi)))
            trans.parametrize()

            # analytical solution for C_ij(lag) is 0.5*A[i]*A[j]*cos(phi[i]-phi[j])*cos(w*lag)
            ana_cov = 0.5*self.A[:, np.newaxis]*self.A*np.cos(self.phi[:, np.newaxis]-self.phi)
            ana_cov_tau = ana_cov*np.cos(self.w*lag)

            self.assertTrue(np.allclose(ana_cov, trans.cov, atol=1.E-3))
            self.assertTrue(np.allclose(ana_cov_tau, trans.cov_tau, atol=1.E-3))
            log.info('max. eigenvalue: %f' % np.max(trans.eigenvalues))
            self.assertTrue(np.all(trans.eigenvalues <= 1.0))

    def test_partial_fit(self):
        """Incremental (partial_fit) TICA must agree with the batch estimate."""
        reader = FeatureReader(self.trajnames, self.temppdb, chunksize=10000)
        output = reader.get_output()
        params = {'dim': self.dim, 'lag': 1001}
        ref = api.tica(reader, **params)
        partial = api.tica(**params)

        for traj in output:
            partial.partial_fit(traj)

        np.testing.assert_allclose(partial.eigenvalues, ref.eigenvalues)
        # only compare first two eigenvectors, because we only have two metastable processes
        np.testing.assert_allclose(np.abs(partial.eigenvectors[:2]),
                                   np.abs(ref.eigenvectors[:2]), rtol=1e-3, atol=1e-3)
116 |
# Allow running this module's tests directly from the command line.
if __name__ == "__main__":
    unittest.main()
119 |
--------------------------------------------------------------------------------
/chainsaw/tests/test_stride.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see .
18 |
19 |
20 |
21 | from __future__ import print_function
22 |
23 | from __future__ import absolute_import
24 | import unittest
25 | import os
26 | import tempfile
27 | import numpy as np
28 | import mdtraj
29 | import chainsaw as coor
30 | from six.moves import range
31 | from six.moves import zip
32 |
class TestStride(unittest.TestCase):
    """Tests that stride subsampling yields consistent lengths and content
    across readers and transformers."""

    @classmethod
    def setUpClass(cls):
        cls.dim = 3 # dimension (must be divisible by 3)
        N_trajs = 10 # number of trajectories

        # create topology file
        # NOTE(review): tempfile.mktemp is deprecated and race-prone;
        # consider tempfile.mkstemp instead.
        cls.temppdb = tempfile.mktemp('.pdb')
        with open(cls.temppdb, 'w') as f:
            for i in range(cls.dim//3):
                print(('ATOM %5d C ACE A 1 28.490 31.600 33.379 0.00 1.00' % i), file=f)

        # ten random trajectories of random length, saved as xtc files
        cls.trajnames = [] # list of xtc file names
        cls.data = []
        for i in range(N_trajs):
            # set up data
            N = int(np.random.rand()*1000+1000)
            xyz = np.random.randn(N, cls.dim//3, 3).astype(np.float32)
            cls.data.append(xyz)
            t = np.arange(0, N)
            # create trajectory file
            traj = mdtraj.load(cls.temppdb)
            traj.xyz = xyz
            traj.time = t
            tempfname = tempfile.mktemp('.xtc')
            traj.save(tempfname)
            cls.trajnames.append(tempfname)

    def test_length_and_content_feature_reader_and_TICA(self):
        # strided output of reader and TICA must agree with the reference
        # lengths and (for the reader) the reference coordinates
        for stride in range(1, 100, 23):
            r = coor.source(self.trajnames, top=self.temppdb)
            t = coor.tica(data=r, lag=2, dim=2, force_eigenvalues_le_one=True)
            # t.data_producer = r
            t.parametrize()

            # subsample data
            out_tica = t.get_output(stride=stride)
            out_reader = r.get_output(stride=stride)

            # get length in different ways
            len_tica = [x.shape[0] for x in out_tica]
            len_reader = [x.shape[0] for x in out_reader]
            len_trajs = t.trajectory_lengths(stride=stride)
            # ceil(len/stride) frames survive striding
            len_ref = [(x.shape[0]-1)//stride+1 for x in self.data]
            # print 'len_ref', len_ref

            # compare length
            np.testing.assert_equal(len_trajs, len_ref)
            self.assertTrue(len_ref == len_tica)
            self.assertTrue(len_ref == len_reader)

            # compare content (reader); loose tolerance because xtc storage
            # is lossy
            for ref_data, test_data in zip(self.data, out_reader):
                ref_data_reshaped = ref_data.reshape((ref_data.shape[0], ref_data.shape[1]*3))
                self.assertTrue(np.allclose(ref_data_reshaped[::stride, :], test_data, atol=1E-3))

    def test_content_data_in_memory(self):
        # prepare test data
        N_trajs = 10
        d = []
        for _ in range(N_trajs):
            N = int(np.random.rand()*1000+10)
            d.append(np.random.randn(N, 10).astype(np.float32))

        # read data
        reader = coor.source(d)

        # compare
        for stride in range(1, 10, 3):
            out_reader = reader.get_output(stride=stride)
            for ref_data, test_data in zip(d, out_reader):
                self.assertTrue(np.all(ref_data[::stride] == test_data)) # here we can test exact equality

    def test_parametrize_with_stride(self):
        for stride in range(1, 100, 23):
            r = coor.source(self.trajnames, top=self.temppdb)
            tau = 5
            try:
                t = coor.tica(r, lag=tau, stride=stride, dim=2, force_eigenvalues_le_one=True)
                # force_eigenvalues_le_one=True enables an internal consistency check in TICA
                t.parametrize(stride=stride)
                self.assertTrue(np.all(t.eigenvalues <= 1.0+1.E-12))
            except RuntimeError:
                # parametrization may legitimately fail only when the lag time
                # is not a multiple of the stride
                assert tau % stride != 0

    @classmethod
    def tearDownClass(cls):
        # remove the generated trajectory and topology files
        for fname in cls.trajnames:
            os.unlink(fname)
        os.unlink(cls.temppdb)
        super(TestStride, cls).tearDownClass()
124 |
# allow running this test module stand-alone
if __name__ == "__main__":
    unittest.main()
--------------------------------------------------------------------------------
/chainsaw/tests/test_regspace.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see .
18 |
19 |
20 | '''
21 | Created on 26.01.2015
22 |
23 | @author: marscher
24 | '''
25 |
26 | from __future__ import absolute_import
27 | import itertools
28 | import unittest
29 |
30 | from chainsaw.clustering.regspace import RegularSpaceClustering
31 | from chainsaw.data.data_in_memory import DataInMemory
32 | from chainsaw.api import cluster_regspace
33 |
34 | import numpy as np
35 | from chainsaw.util import types
36 |
37 |
class RandomDataSource(DataInMemory):
    """In-memory source of uniformly distributed random samples.

    When both ``a`` and ``b`` are given, the samples lie in the
    interval [a, b]; otherwise they lie in [0, 1).
    """

    def __init__(self, a=None, b=None, chunksize=100, n_samples=1000, dim=3):
        """
        creates random values in interval [a,b]
        """
        samples = np.random.random((n_samples, dim))
        if not (a is None or b is None):
            # affine map [0, 1) -> [a, b)
            samples = samples * (b - a) + a
        super(RandomDataSource, self).__init__(samples, chunksize=chunksize)
49 |
50 |
class TestRegSpaceClustering(unittest.TestCase):
    """Tests regular-space clustering on random data: minimum distance
    between centers, assignment of new data, metric handling and the
    max_centers warning."""

    def setUp(self):
        # fresh estimator fed by uniform random data for every test
        self.dmin = 0.3
        self.clustering = RegularSpaceClustering(dmin=self.dmin)
        self.clustering.data_producer = RandomDataSource()

    def test_algorithm(self):
        """Every pair of distinct cluster centers must be >= dmin apart."""
        self.clustering.parametrize()

        # correct type of dtrajs
        assert types.is_int_vector(self.clustering.dtrajs[0])

        # assert distance for each centroid is at least dmin
        for c in itertools.combinations(self.clustering.clustercenters, 2):
            if np.allclose(c[0], c[1]):  # skip equal pairs
                continue

            dist = np.linalg.norm(c[0] - c[1], 2)

            self.assertGreaterEqual(dist, self.dmin,
                                    "centroid pair\n%s\n%s\n has smaller"
                                    " distance than dmin(%f): %f"
                                    % (c[0], c[1], self.dmin, dist))

    def test_assignment(self):
        """Number of unique states in dtrajs equals the number of centers,
        and previously unseen data can be assigned."""
        self.clustering.parametrize()

        assert len(self.clustering.clustercenters) > 1

        # num states == num _clustercenters?
        self.assertEqual(len(np.unique(self.clustering.dtrajs)), len(
            self.clustering.clustercenters), "number of unique states in dtrajs"
            " should be equal.")

        data_to_cluster = np.random.random((1000, 3))

        self.clustering.assign(data_to_cluster, stride=1)

    def test_spread_data(self):
        # data spread over [-2, 2] with dmin=2 -> few, well separated centers
        self.clustering.data_producer = RandomDataSource(a=-2, b=2)
        self.clustering.dmin = 2
        self.clustering.parametrize()

    def test1d_data(self):
        # 1d input must be accepted as well
        data = np.random.random(100)
        cluster_regspace(data, dmin=0.3)

    def test_non_existent_metric(self):
        # an unknown metric name must raise ValueError during estimation
        self.clustering.data_producer = RandomDataSource(a=-2, b=2)
        self.clustering.dmin = 2
        self.clustering.metric = "non_existent_metric"
        with self.assertRaises(ValueError):
            self.clustering.parametrize()

    def test_minRMSD_metric(self):
        # estimation and assignment must work with the minRMSD metric
        self.clustering.data_producer = RandomDataSource(a=-2, b=2)
        self.clustering.dmin = 2
        self.clustering.metric = "minRMSD"
        self.clustering.parametrize()

        data_to_cluster = np.random.random((1000, 3))

        self.clustering.assign(data_to_cluster, stride=1)

    def test_too_small_dmin_should_warn(self):
        """A tiny dmin would create too many centers: the estimator should
        stop at max_centers and emit exactly one warning."""
        self.clustering.dmin = 1e-8
        max_centers = 50
        self.clustering.max_centers = max_centers
        import warnings
        with warnings.catch_warnings(record=True) as w:
            # Cause all warnings to always be triggered.
            warnings.simplefilter("always")
            # Trigger a warning.
            self.clustering.estimate(self.clustering.data_producer)
            assert w
            assert len(w) == 1

        assert len(self.clustering.clustercenters) == max_centers

        # assign data
        out = self.clustering.get_output()
        assert len(out) == self.clustering.number_of_trajectories()
        assert len(out[0]) == self.clustering.trajectory_lengths()[0]
135 |
# allow running this test module stand-alone
if __name__ == "__main__":
    unittest.main()
138 |
--------------------------------------------------------------------------------
/chainsaw/tests/test_discretizer.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see .
18 |
19 | '''
20 | Created on 19.01.2015
21 |
22 | @author: marscher
23 | '''
24 |
25 | from __future__ import absolute_import
26 | import os
27 | import tempfile
28 | import unittest
29 | import mdtraj
30 | import numpy as np
31 |
32 | from mdtraj.core.trajectory import Trajectory
33 | from mdtraj.core.element import hydrogen, oxygen
34 | from mdtraj.core.topology import Topology
35 |
36 | from chainsaw.clustering.uniform_time import UniformTimeClustering
37 | from chainsaw.pipelines import Discretizer
38 | from chainsaw.data.data_in_memory import DataInMemory
39 | from chainsaw.api import cluster_kmeans, pca, source
40 | from six.moves import range
41 |
42 |
def create_water_topology_on_disc(n):
    """Write a PDB file containing *n* water-like residues (H-O-H) with all
    coordinates at the origin and return its path.

    The caller is responsible for deleting the file.
    """
    # mkstemp instead of the deprecated, race-prone tempfile.mktemp
    # (a predicted-only name can be taken by another process before use)
    fd, topfile = tempfile.mkstemp('.pdb')
    os.close(fd)
    top = Topology()
    chain = top.add_chain()

    for i in range(n):
        res = top.add_residue('r%i' % i, chain)
        h1 = top.add_atom('H', hydrogen, res)
        o = top.add_atom('O', oxygen, res)
        h2 = top.add_atom('H', hydrogen, res)
        top.add_bond(h1, o)
        top.add_bond(h2, o)

    # 3 atoms per residue, all at the origin
    xyz = np.zeros((n * 3, 3))
    Trajectory(xyz, top).save_pdb(topfile)
    return topfile
59 |
60 |
def create_traj_on_disc(topfile, n_frames, n_atoms):
    """Write an xtc trajectory of random coordinates for the given topology
    and return its path. The caller is responsible for deleting the file."""
    # mkstemp instead of the deprecated, race-prone tempfile.mktemp
    fd, fn = tempfile.mkstemp('.xtc')
    os.close(fd)
    xyz = np.random.random((n_frames, n_atoms, 3))
    t = mdtraj.load(topfile)
    t.xyz = xyz
    t.time = np.arange(n_frames)
    t.save(fn)
    return fn
69 |
70 |
class TestDiscretizer(unittest.TestCase):
    """Tests the Discretizer pipeline (reader -> optional transform ->
    clustering) on small random trajectories."""

    @classmethod
    def setUpClass(cls):
        super(TestDiscretizer, cls).setUpClass()
        # create a fake trajectory which has 30 water-like residues and
        # random coordinates over all frames.
        cls.n_frames = 1000
        cls.n_residues = 30
        cls.topfile = create_water_topology_on_disc(cls.n_residues)

        # create some trajectories
        t1 = create_traj_on_disc(
            cls.topfile, cls.n_frames, cls.n_residues * 3)

        t2 = create_traj_on_disc(
            cls.topfile, cls.n_frames, cls.n_residues * 3)

        cls.trajfiles = [t1, t2]

        # output directory for save_dtrajs
        cls.dest_dir = tempfile.mkdtemp()

    @classmethod
    def tearDownClass(cls):
        """delete temporary files"""
        os.unlink(cls.topfile)
        for f in cls.trajfiles:
            os.unlink(f)

        import shutil
        shutil.rmtree(cls.dest_dir, ignore_errors=True)

    def test(self):
        """Full pipeline with PCA and uniform time clustering yields one
        dtraj per input trajectory using all clusters."""
        reader = source(self.trajfiles, top=self.topfile)
        pcat = pca(dim=2)

        n_clusters = 2
        clustering = UniformTimeClustering(n_clusters=n_clusters)

        D = Discretizer(reader, transform=pcat, cluster=clustering)
        D.parametrize()

        self.assertEqual(len(D.dtrajs), len(self.trajfiles))

        for dtraj in clustering.dtrajs:
            unique = np.unique(dtraj)
            self.assertEqual(unique.shape[0], n_clusters)

    def test_with_data_in_mem(self):
        """Pipeline on in-memory data: states never exceed cluster count."""
        import chainsaw as api

        data = [np.random.random((100, 50)),
                np.random.random((103, 50)),
                np.random.random((33, 50))]
        reader = source(data)
        assert isinstance(reader, DataInMemory)

        tpca = api.pca(dim=2)

        n_centers = 10
        km = api.cluster_kmeans(k=n_centers)

        disc = api.discretizer(reader, tpca, km)
        disc.parametrize()

        dtrajs = disc.dtrajs
        for dtraj in dtrajs:
            n_states = np.max((np.unique(dtraj)))
            self.assertGreaterEqual(n_centers - 1, n_states,
                                    "dtraj has more states than cluster centers")

    def test_save_dtrajs(self):
        """save_dtrajs writes one discrete trajectory file per input."""
        reader = source(self.trajfiles, top=self.topfile)
        cluster = cluster_kmeans(k=2)
        d = Discretizer(reader, cluster=cluster)
        d.parametrize()
        d.save_dtrajs(output_dir=self.dest_dir)
        dtrajs = os.listdir(self.dest_dir)
        # previously the listing was computed but never checked
        self.assertEqual(len(dtrajs), len(self.trajfiles))
151 |
152 |
# allow running this test module stand-alone
if __name__ == "__main__":
    unittest.main()
155 |
--------------------------------------------------------------------------------
/chainsaw/acf.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
18 |
19 |
20 |
21 | from __future__ import absolute_import, print_function
22 | import numpy as np
23 | import sys
24 | from six.moves import range
25 |
26 | __author__ = 'Fabian Paul'
27 | __all__ = ['acf']
28 |
29 |
def acf(trajs, stride=1, max_lag=None, subtract_mean=True, normalize=True, mean=None):
    '''Computes the (combined) autocorrelation function of multiple trajectories.

    Parameters
    ----------
    trajs : list of (*,N) ndarrays
        the observable trajectories, N is the number of observables
    stride : int (default = 1)
        only take every n'th frame from trajs
    max_lag : int (default = maximum trajectory length / stride)
        only compute acf up to this lag time
    subtract_mean : bool (default = True)
        subtract trajectory mean before computing acfs
    normalize : bool (default = True)
        divide acf by the variance such that acf[0,:]==1
    mean : (N) ndarray (optional)
        if subtract_mean is True, you can give the trajectory mean
        so this functions doesn't have to compute it again

    Returns
    -------
    acf : (max_lag,N) ndarray
        autocorrelation functions for all observables

    Note
    ----
    The computation uses FFT (with zero-padding) and is done in memory (RAM).
    The input trajectories are never modified.
    '''
    if not isinstance(trajs, list):
        trajs = [trajs]

    # bring all trajectories into (T, N) shape
    mytrajs = [None] * len(trajs)
    for i in range(len(trajs)):
        if trajs[i].ndim == 1:
            mytrajs[i] = trajs[i].reshape((trajs[i].shape[0], 1))
        elif trajs[i].ndim == 2:
            mytrajs[i] = trajs[i]
        else:
            raise Exception(
                'Unexpected number of dimensions in trajectory number %d' % i)
    trajs = mytrajs

    assert stride > 0, 'stride must be > 0'
    assert max_lag is None or max_lag > 0, 'max_lag must be > 0'

    if subtract_mean and mean is None:
        # compute mean over all trajectories; accumulate in float64 so that
        # integer input does not break the in-place true division below
        mean = trajs[0].sum(axis=0, dtype=np.float64)
        n_samples = trajs[0].shape[0]
        for i, traj in enumerate(trajs[1:]):
            if traj.shape[1] != mean.shape[0]:
                raise Exception(('number of order parameters in trajectory number %d differs ' +
                                 'from the number found in previous trajectories.') % (i + 1))
            mean += traj.sum(axis=0)
            n_samples += traj.shape[0]
        mean /= n_samples

    res = np.array([[]])
    # number of samples for every tau
    N = np.array([])

    for i, traj in enumerate(trajs):
        # copy! traj[::stride] is only a view into the caller's array; the
        # in-place mean subtraction below used to silently corrupt the
        # input trajectories (bug fix)
        data = np.array(traj[::stride], dtype=np.float64)
        if subtract_mean:
            data -= mean
        # calc acfs via FFT; zero-pad to the next power of two >= 2*l-1 so
        # the circular correlation equals the linear correlation
        l = data.shape[0]
        fft = np.fft.fft(data, n=2 ** int(np.ceil(np.log2(l * 2 - 1))), axis=0)
        acftraj = np.fft.ifft(fft * np.conjugate(fft), axis=0).real
        # throw away acf data for long lag times (and negative lag times)
        if max_lag and max_lag < l:
            acftraj = acftraj[:max_lag, :]
        else:
            acftraj = acftraj[:l, :]
            if max_lag:
                sys.stderr.write(
                    'Warning: trajectory number %d is shorter than maximum lag.\n' % i)
        # find number of samples used for every lag: l, l-1, ...
        Ntraj = np.linspace(l, l - acftraj.shape[0] + 1, acftraj.shape[0])
        # adapt shape of acf: resize temporal dimension, additionally set
        # number of order parameters of acf in the first step
        if res.shape[1] < acftraj.shape[1] and res.shape[1] > 0:
            raise Exception(('number of order parameters in trajectory number %d differs ' +
                             'from the number found in previous trajectories.') % i)
        if res.shape[1] < acftraj.shape[1] or res.shape[0] < acftraj.shape[0]:
            res = np.resize(res, acftraj.shape)
            N = np.resize(N, acftraj.shape[0])
        # update acf and number of samples
        res[0:acftraj.shape[0], :] += acftraj
        N[0:acftraj.shape[0]] += Ntraj

    # divide by number of samples
    res = np.transpose(np.transpose(res) / N)

    # normalize acfs
    if normalize:
        res /= res[0, :].copy()

    return res
--------------------------------------------------------------------------------
/chainsaw/util/units.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
18 |
19 |
20 |
21 | from __future__ import absolute_import
22 | __author__ = 'noe'
23 |
24 | import numpy as np
25 | import math
26 |
class TimeUnit(object):
    """Represents a physical time unit as a numeric factor plus a unit,
    e.g. '10 ps'. Units are stored internally as integer codes (_UNIT_*)."""

    _UNIT_STEP = -1
    _UNIT_FS = 0
    _UNIT_PS = 1
    _UNIT_NS = 2
    _UNIT_US = 3
    _UNIT_MS = 4
    _UNIT_S = 5
    # names indexed by unit code (not valid for _UNIT_STEP)
    _unit_names = ['fs', 'ps', 'ns', 'us', 'ms', 's']

    def __init__(self, unit='1 step'):
        """
        Initializes the time unit object

        Parameters
        ----------
        unit : str or TimeUnit
            Description of a physical time unit. By default '1 step', i.e. there is no physical time unit.
            Specify by a number, whitespace and unit. Permitted units are (* is an arbitrary string):
            'fs', 'femtosecond*'
            'ps', 'picosecond*'
            'ns', 'nanosecond*'
            'us', 'microsecond*'
            'ms', 'millisecond*'
            's', 'second*'
            Passing another TimeUnit copies it.

        Raises
        ------
        ValueError
            If the string cannot be parsed or the unit is not recognized.
        """
        if isinstance(unit, TimeUnit):  # copy constructor
            self._factor = unit._factor
            self._unit = unit._unit
        else:  # construct from string
            lunit = unit.lower()
            words = lunit.split(' ')

            if len(words) == 1:
                # no number given, e.g. 'ps' -> factor 1
                self._factor = 1.0
                unitstring = words[0]
            elif len(words) == 2:
                self._factor = float(words[0])
                unitstring = words[1]
            else:
                raise ValueError('Illegal input string: '+str(unit))

            if unitstring == 'step':
                self._unit = self._UNIT_STEP
            elif unitstring == 'fs' or unitstring.startswith('femtosecond'):
                self._unit = self._UNIT_FS
            elif unitstring == 'ps' or unitstring.startswith('picosecond'):
                self._unit = self._UNIT_PS
            elif unitstring == 'ns' or unitstring.startswith('nanosecond'):
                self._unit = self._UNIT_NS
            elif unitstring == 'us' or unitstring.startswith('microsecond'):
                self._unit = self._UNIT_US
            elif unitstring == 'ms' or unitstring.startswith('millisecond'):
                self._unit = self._UNIT_MS
            elif unitstring == 's' or unitstring.startswith('second'):
                self._unit = self._UNIT_S
            else:
                raise ValueError('Time unit is not understood: '+unit)

    def __str__(self):
        if self._unit == self._UNIT_STEP:
            return str(self._factor)+' step'
        else:
            return str(self._factor)+' '+self._unit_names[self._unit]

    @property
    def dt(self):
        # numeric multiple of the physical unit
        return self._factor

    @property
    def unit(self):
        # integer unit code (one of the _UNIT_* constants)
        return self._unit

    def get_scaled(self, factor):
        """ Get a new time unit, scaled by the given factor """
        import copy
        res = copy.deepcopy(self)
        res._factor *= factor
        return res

    def rescale_around1(self, times):
        """
        Suggests a rescaling factor and new physical time unit to balance the given time multiples around 1.

        Parameters
        ----------
        times : float array
            array of times in multiple of the present elementary unit

        Returns
        -------
        (rescaled_times, unit_name) : (float array, str)
            the rescaled times and the name of the suggested unit
        """
        if self._unit == self._UNIT_STEP:
            return times, 'step'  # nothing to do

        m = np.mean(times)
        mult = 1.0
        cur_unit = self._unit

        # numbers are too small: make them larger and reduce the unit.
        # Stop at 'fs' (code 0). cur_unit > 0 fixes the old boundary bug
        # where the counter dropped to -1 and indexed _unit_names[-1] ('s').
        if m < 0.001:
            while mult*m < 0.001 and cur_unit > 0:
                mult *= 1000
                cur_unit -= 1
            return mult*times, self._unit_names[cur_unit]

        # numbers are too large: make them smaller and increase the unit.
        # Stop at 's' (code 5). cur_unit < _UNIT_S fixes the old boundary
        # bug where the counter moved past the last name -> IndexError.
        if m > 1000:
            while mult*m > 1000 and cur_unit < self._UNIT_S:
                mult /= 1000
                cur_unit += 1
            return mult*times, self._unit_names[cur_unit]

        # nothing to do; return the unit *name* for consistency with the
        # branches above (previously the integer unit code leaked out here)
        return times, self._unit_names[self._unit]
142 |
def bytes_to_string(num, suffix='B'):
    """
    Returns the size of num (bytes) in a human readable form up to Yottabytes (YB).
    :param num: The size of interest in bytes.
    :param suffix: A suffix, default 'B' for 'bytes'.
    :return: a human readable representation of a size in bytes
    """
    prefixes = ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y']
    if num == 0:
        return "0%s" % suffix
    magnitude = float(abs(num))
    # 1024-based order of magnitude selects the prefix
    exponent = int(math.floor(math.log(magnitude, 1024)))
    scaled = np.sign(num) * (magnitude / pow(1024, exponent))
    return "%.1f%s%s" % (scaled, prefixes[exponent], suffix)
157 |
--------------------------------------------------------------------------------
/chainsaw/tests/test_coordinates_iterator.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import numpy as np
3 |
4 | from chainsaw.data import DataInMemory
5 | from chainsaw.util.files import TemporaryDirectory
6 | import os
7 | from glob import glob
8 |
9 |
class TestCoordinatesIterator(unittest.TestCase):
    """Exercises the iterator protocol of DataInMemory: trajectory-index
    bookkeeping, chunking, skipping, striding and CSV output."""

    @classmethod
    def setUpClass(cls):
        # three random trajectories of 100 frames each, 3 dimensions
        cls.d = [np.random.random((100, 3)) for _ in range(3)]

    def test_current_trajindex(self):
        """The itraj yielded by the iterator matches current_trajindex."""
        r = DataInMemory(self.d)
        expected_itraj = 0
        # with chunk=0 every iteration yields one full trajectory
        for itraj, X in r.iterator(chunk=0):
            assert itraj == expected_itraj
            expected_itraj += 1

        expected_itraj = -1
        it = r.iterator(chunk=16)
        for itraj, X in it:
            if it.pos == 0:  # pos resets at the start of each trajectory
                expected_itraj += 1
            assert itraj == expected_itraj == it.current_trajindex

    def test_n_chunks(self):
        """The _n_chunks prediction agrees with the actual chunk count."""
        r = DataInMemory(self.d)

        it0 = r.iterator(chunk=0)
        assert it0._n_chunks == 3  # 3 trajs

        it1 = r.iterator(chunk=50)
        assert it1._n_chunks == 3 * 2  # 2 chunks per trajectory

        it2 = r.iterator(chunk=30)
        # 3 full chunks and 1 small chunk per trajectory
        assert it2._n_chunks == 3 * 4

        it3 = r.iterator(chunk=30)
        it3.skip = 10
        assert it3._n_chunks == 3 * 3  # 3 full chunks per traj

        it4 = r.iterator(chunk=30)
        it4.skip = 5
        # 3 full chunks and 1 chunk of 5 frames per trajectory
        assert it4._n_chunks == 3 * 4

        # test for lagged iterator
        for stride in range(1, 5):
            for lag in range(0, 18):
                it = r.iterator(
                    lag=lag, chunk=30, stride=stride, return_trajindex=False)
                chunks = 0
                for _ in it:
                    chunks += 1
                assert chunks == it._n_chunks

    def test_skip(self):
        """skip can be set per trajectory and the lagged iterator pair
        carries the lag as the skip of its second iterator."""
        r = DataInMemory(self.d)
        lagged_it = r.iterator(lag=5)
        assert lagged_it._it.skip == 0
        assert lagged_it._it_lagged.skip == 5

        it = r.iterator()
        for itraj, X in it:
            if itraj == 0:
                it.skip = 5
            if itraj == 1:
                assert it.skip == 5

    def test_chunksize(self):
        """chunksize may be changed while iterating; only the last chunk of
        a trajectory may be smaller than chunksize."""
        r = DataInMemory(self.d)
        cs = np.arange(1, 17)
        i = 0
        it = r.iterator(chunk=cs[i])
        for itraj, X in it:
            if not it.last_chunk_in_traj:
                assert len(X) == it.chunksize
            else:
                assert len(X) <= it.chunksize
            i += 1
            i %= len(cs)
            it.chunksize = cs[i]
            assert it.chunksize == cs[i]

    def test_last_chunk(self):
        """With chunk=0 every chunk is the last of its trajectory."""
        r = DataInMemory(self.d)
        it = r.iterator(chunk=0)
        for itraj, X in it:
            assert it.last_chunk_in_traj
            if itraj == 2:
                assert it.last_chunk

    def test_stride(self):
        """stride may be changed while iterating."""
        r = DataInMemory(self.d)
        stride = np.arange(1, 17)
        i = 0
        it = r.iterator(stride=stride[i], chunk=1)
        for _ in it:
            i += 1
            i %= len(stride)
            it.stride = stride[i]
            assert it.stride == stride[i]

    def test_return_trajindex(self):
        """return_traj_index toggles between (itraj, X) tuples and bare X."""
        r = DataInMemory(self.d)
        it = r.iterator(chunk=0)
        it.return_traj_index = True
        assert it.return_traj_index is True
        for tup in it:
            assert len(tup) == 2
        it.reset()
        it.return_traj_index = False
        assert it.return_traj_index is False
        itraj = 0
        for tup in it:
            np.testing.assert_equal(tup, self.d[itraj])
            itraj += 1

        for tup in r.iterator(return_trajindex=True):
            assert len(tup) == 2
        itraj = 0
        for tup in r.iterator(return_trajindex=False):
            np.testing.assert_equal(tup, self.d[itraj])
            itraj += 1

    def test_pos(self):
        """pos tracks the number of frames already yielded within the
        current trajectory."""
        r = DataInMemory(self.d)
        r.chunksize = 17
        it = r.iterator()
        t = 0
        for itraj, X in it:
            assert t == it.pos
            t += len(X)
            if it.last_chunk_in_traj:
                t = 0

    def test_write_to_csv_propagate_filenames(self):
        """write_to_csv derives output names from the reader's file names
        and the written files reproduce the transformer output."""
        from chainsaw import source, tica
        with TemporaryDirectory() as td:
            data = [np.random.random((20, 3))] * 3
            fns = [os.path.join(td, f)
                   for f in ('blah.npy', 'blub.npy', 'foo.npy')]
            for x, fn in zip(data, fns):
                np.save(fn, x)
            reader = source(fns)
            assert reader.filenames == fns
            tica_obj = tica(reader, lag=1, dim=2)
            tica_obj.write_to_csv(extension=".txt", chunksize=3)
            res = sorted([os.path.abspath(x) for x in glob(td + os.path.sep + '*.txt')])
            self.assertEqual(len(res), len(fns))
            desired_fns = sorted([s.replace('.npy', '.txt') for s in fns])
            self.assertEqual(res, desired_fns)

            # compare written results
            expected = tica_obj.get_output()
            actual = source(list(s.replace('.npy', '.txt') for s in fns)).get_output()
            assert len(actual) == len(fns)
            for a, e in zip(actual, expected):
                np.testing.assert_allclose(a, e)
165 |
# allow running this test module stand-alone
if __name__ == '__main__':
    unittest.main()
168 |
--------------------------------------------------------------------------------
/chainsaw/util/reflection.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
18 |
19 | from __future__ import division, print_function, absolute_import
20 |
21 | import inspect
22 | import six
23 | #from six import string_types
24 | from collections import namedtuple
25 |
26 | __author__ = 'noe, marscher'
27 |
28 |
29 | # Add a replacement for inspect.getargspec() which is deprecated in python 3.5
30 | # The version below is borrowed from Django,
31 | # https://github.com/django/django/pull/4846
32 |
33 | # Note an inconsistency between inspect.getargspec(func) and
34 | # inspect.signature(func). If `func` is a bound method, the latter does *not*
35 | # list `self` as a first argument, while the former *does*.
36 | # Hence cook up a common ground replacement: `getargspec_no_self` which
37 | # mimics `inspect.getargspec` but does not list `self`.
38 | #
39 | # This way, the caller code does not need to know whether it uses a legacy
40 | # .getargspec or bright and shiny .signature.
41 |
try:
    # is it python 3.3 or higher?
    inspect.signature

    # Apparently, yes. Wrap inspect.signature

    ArgSpec = namedtuple('ArgSpec', ['args', 'varargs', 'keywords', 'defaults'])

    def getargspec_no_self(func):
        """inspect.getargspec replacement using inspect.signature.

        inspect.getargspec is deprecated in python 3. This is a replacement
        based on the (new in python 3.3) `inspect.signature`.

        Parameters
        ----------
        func : callable
            A callable to inspect

        Returns
        -------
        argspec : ArgSpec(args, varargs, keywords, defaults)
            This is similar to the result of inspect.getargspec(func) under
            python 2.x.
            NOTE: if the first argument of `func` is self, it is *not*, I repeat
            *not* included in argspec.args.
            This is done for consistency between inspect.getargspec() under
            python 2.x, and inspect.signature() under python 3.x.
        """
        sig = inspect.signature(func)
        args = [
            p.name for p in sig.parameters.values()
            if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
        ]
        varargs = [
            p.name for p in sig.parameters.values()
            if p.kind == inspect.Parameter.VAR_POSITIONAL
        ]
        varargs = varargs[0] if varargs else None
        varkw = [
            p.name for p in sig.parameters.values()
            if p.kind == inspect.Parameter.VAR_KEYWORD
        ]
        varkw = varkw[0] if varkw else None
        defaults = [
            p.default for p in sig.parameters.values()
            if (p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD and
                p.default is not p.empty)
        ] or None

        # guard against callables without positional args (e.g. only *args):
        # indexing args[0] unconditionally raised IndexError before
        if args and args[0] == 'self':
            args.pop(0)

        return ArgSpec(args, varargs, varkw, defaults)

except AttributeError:
    # python 2.x
    def getargspec_no_self(func):
        """inspect.getargspec replacement for compatibility with python 3.x.

        inspect.getargspec is deprecated in python 3. This wraps it, and
        *removes* `self` from the argument list of `func`, if present.
        This is done for forward compatibility with python 3.

        Parameters
        ----------
        func : callable
            A callable to inspect

        Returns
        -------
        argspec : ArgSpec(args, varargs, keywords, defaults)
            This is similar to the result of inspect.getargspec(func) under
            python 2.x.
            NOTE: if the first argument of `func` is self, it is *not*, I repeat
            *not* included in argspec.args.
            This is done for consistency between inspect.getargspec() under
            python 2.x, and inspect.signature() under python 3.x.
        """
        argspec = inspect.getargspec(func)
        # same IndexError guard as the python 3 branch above
        if argspec.args and argspec.args[0] == 'self':
            argspec.args.pop(0)
        return argspec
125 |
126 |
def call_member(obj, f, *args, **kwargs):
    """ Calls the specified method, property or attribute of the given object

    Parameters
    ----------
    obj : object
        The object that will be used
    f : str or function
        Name of or reference to method, property or attribute
    *args, **kwargs
        Passed through to the member if it turns out to be a method

    Returns
    -------
    The return value of the method call, or the value of the
    attribute/property if the member is not callable as a method.
    """
    # get function name; getattr fallback also supports plain functions,
    # which (unlike bound methods) have no __func__ attribute
    if isinstance(f, six.string_types):
        fname = f
    else:
        fname = getattr(f, '__func__', f).__name__
    # get the method ref
    method = getattr(obj, fname)
    # handle cases: call methods, return attributes/properties as-is
    if inspect.ismethod(method):
        return method(*args, **kwargs)

    # attribute or property
    return method
153 |
154 |
def get_default_args(func):
    """
    returns a dictionary of arg_name:default_values for the input function

    Parameters
    ----------
    func : callable
        function to inspect (a leading `self` parameter is ignored by
        getargspec_no_self)

    Returns
    -------
    dict
        mapping of argument name to default value; empty if the function
        declares no defaults.
    """
    args, varargs, keywords, defaults = getargspec_no_self(func)
    # getargspec_no_self reports None (not an empty tuple) when there are no
    # defaults; the original `len(defaults)` then raised TypeError, and an
    # empty tuple would have wrongly mapped *all* args via args[0:].
    if not defaults:
        return {}
    return dict(zip(args[-len(defaults):], defaults))
161 |
--------------------------------------------------------------------------------
/chainsaw/tests/test_featurereader_and_tica_projection.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see .
18 |
19 |
20 | '''
21 | Test feature reader and Tica by checking the properties of the ICs.
22 | cov(ic_i,ic_j) = delta_ij and cov(ic_i,ic_j,tau) = lambda_i delta_ij
23 | @author: Fabian Paul
24 | '''
25 |
26 | from __future__ import absolute_import
27 | from __future__ import print_function
28 |
29 | import os
30 | import tempfile
31 | import unittest
32 | from logging import getLogger
33 |
34 | import mdtraj
35 | import numpy as np
36 | from chainsaw.api import tica
37 | from chainsaw.data.md.feature_reader import FeatureReader
38 | from chainsaw.util.contexts import numpy_random_seed
39 | from nose.plugins.attrib import attr
40 | from six.moves import range
41 |
# Module-level logger for this test module.
log = getLogger('chainsaw.TestFeatureReaderAndTICAProjection')
43 |
44 |
def random_invertible(n, eps=0.01):
    """Generate a random real invertible (n, n) matrix.

    Draws a standard-normal matrix and clips its singular values from below
    at `eps`, so the reconstructed matrix is safely invertible.
    """
    u, s, v = np.linalg.svd(np.random.randn(n, n))
    clipped = np.maximum(s, eps)
    return np.dot(np.dot(u, np.diag(clipped)), v)
51 |
52 |
@attr(slow=True)
class TestFeatureReaderAndTICAProjection(unittest.TestCase):
    # Verifies TICA on synthetic mixed Brownian data: the independent
    # components must be uncorrelated with unit variance, and the lagged
    # autocovariance of component i must equal eigenvalue i.

    @classmethod
    def setUpClass(cls):
        # Generate all random data under a fixed seed so the test is
        # deterministic across runs.
        with numpy_random_seed(52):
            # NOTE(review): the return value `c` is never used.
            c = super(TestFeatureReaderAndTICAProjection, cls).setUpClass()

            cls.dim = 99  # dimension (must be divisible by 3)
            N = 5000  # length of single trajectory # 500000 # 50000
            N_trajs = 10  # number of trajectories

            A = random_invertible(cls.dim)  # mixing matrix
            # tica will approximate its inverse with the projection matrix
            mean = np.random.randn(cls.dim)

            # create topology file with dim/3 carbon atoms (3 coords each)
            cls.temppdb = tempfile.mktemp('.pdb')
            with open(cls.temppdb, 'w') as f:
                for i in range(cls.dim // 3):
                    print(('ATOM %5d C ACE A 1 28.490 31.600 33.379 0.00 1.00' % i), file=f)

            t = np.arange(0, N)
            cls.trajnames = []  # list of xtc file names
            for i in range(N_trajs):
                # set up data: white noise -> Brownian walk -> linear mix + shift
                white = np.random.randn(N, cls.dim)
                brown = np.cumsum(white, axis=0)
                correlated = np.dot(brown, A)
                data = correlated + mean
                xyz = data.reshape((N, cls.dim // 3, 3))
                # create trajectory file holding the synthetic coordinates
                traj = mdtraj.load(cls.temppdb)
                traj.xyz = xyz
                traj.time = t
                tempfname = tempfile.mktemp('.xtc')
                traj.save(tempfname)
                cls.trajnames.append(tempfname)

    @classmethod
    def tearDownClass(cls):
        # remove all temporary trajectory and topology files
        for fname in cls.trajnames:
            os.unlink(fname)
        os.unlink(cls.temppdb)
        super(TestFeatureReaderAndTICAProjection, cls).tearDownClass()

    def test_covariances_and_eigenvalues(self):
        # For several lag times: project the data with TICA, then run TICA
        # again on the projection and check the IC properties.
        reader = FeatureReader(self.trajnames, self.temppdb)
        for tau in [1, 10, 100, 1000, 2000]:
            trans = tica(lag=tau, dim=self.dim, kinetic_map=False)
            trans.data_producer = reader

            log.info('number of trajectories reported by tica %d' % trans.number_of_trajectories())
            trans.parametrize()
            data = trans.get_output()

            log.info('max. eigenvalue: %f' % np.max(trans.eigenvalues))
            self.assertTrue(np.all(trans.eigenvalues <= 1.0))
            # check ICs: cov(ic_i, ic_j) = delta_ij, mean-free output
            check = tica(data=data, lag=tau, dim=self.dim)

            np.testing.assert_allclose(np.eye(self.dim), check.cov, atol=1e-8)
            np.testing.assert_allclose(check.mean, 0.0, atol=1e-8)
            # lagged covariance must be diagonal with the eigenvalues on it
            ic_cov_tau = np.zeros((self.dim, self.dim))
            ic_cov_tau[np.diag_indices(self.dim)] = trans.eigenvalues
            np.testing.assert_allclose(ic_cov_tau, check.cov_tau, atol=1e-8)

    def test_partial_fit(self):
        # Feeding trajectories one-by-one via partial_fit must give the same
        # model as a single batch estimation.
        from chainsaw import source
        reader = source(self.trajnames, top=self.temppdb)
        reader_output = reader.get_output()

        params = {'lag': 10, 'kinetic_map': False, 'dim': self.dim}

        tica_obj = tica(**params)
        tica_obj.partial_fit(reader_output[0])
        assert not tica_obj._estimated
        # access eigenvectors to force diagonalization
        tica_obj.eigenvectors
        assert tica_obj._estimated

        # another chunk invalidates the previous diagonalization
        tica_obj.partial_fit(reader_output[1])
        assert not tica_obj._estimated

        tica_obj.eigenvalues
        assert tica_obj._estimated

        for traj in reader_output[2:]:
            tica_obj.partial_fit(traj)

        # reference: batch estimation over the full reader
        ref = tica(reader, **params)

        np.testing.assert_allclose(tica_obj.cov, ref.cov, atol=1e-15)
        np.testing.assert_allclose(tica_obj.cov_tau, ref.cov_tau, atol=1e-15)

        np.testing.assert_allclose(tica_obj.eigenvalues, ref.eigenvalues, atol=1e-15)
        # we do not test eigenvectors here, since the system is very metastable and
        # we have multiple eigenvalues very close to one.
151 |
152 | if __name__ == "__main__":
153 | unittest.main()
154 |
--------------------------------------------------------------------------------
/chainsaw/tests/test_cluster.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see .
18 |
19 |
20 |
21 | from __future__ import absolute_import
22 | import unittest
23 | import os
24 | import tempfile
25 |
26 | import numpy as np
27 |
28 | from logging import getLogger
29 | import chainsaw as coor
30 | from chainsaw.util import types
31 | from six.moves import range
32 |
33 |
# Module-level logger for this test module.
logger = getLogger('chainsaw.TestReaderUtils')
35 |
36 |
class TestCluster(unittest.TestCase):
    """Exercises the API shared by the clustering estimators (k-means,
    regular space, uniform time) on a synthetic Gaussian-mixture data set."""

    @classmethod
    def setUpClass(cls):
        super(TestCluster, cls).setUpClass()
        # scratch directory for test_save_dtrajs output; removed in tearDownClass
        cls.dtraj_dir = tempfile.mkdtemp()

        # generate Gaussian mixture: five components in 2D
        means = [np.array([-3, 0]),
                 np.array([-1, 1]),
                 np.array([0, 0]),
                 np.array([1, -1]),
                 np.array([4, 2])]
        widths = [np.array([0.3, 2]),
                  np.array([0.3, 2]),
                  np.array([0.3, 2]),
                  np.array([0.3, 2]),
                  np.array([0.3, 2])]
        # continuous trajectory: nsample frames per component, concatenated
        nsample = 1000
        cls.T = len(means) * nsample
        cls.X = np.zeros((cls.T, 2))
        for i in range(len(means)):
            # NOTE(review): np.random.randn() (no size argument) draws a single
            # scalar, so every frame of a component gets the same offset. If
            # per-frame noise was intended, this should be
            # np.random.randn(nsample) — kept as-is since the assertions below
            # do not depend on it; confirm intent before changing.
            cls.X[i * nsample:(i + 1) * nsample, 0] = widths[i][0] * np.random.randn() + means[i][0]
            cls.X[i * nsample:(i + 1) * nsample, 1] = widths[i][1] * np.random.randn() + means[i][1]
        # cluster the same data in three different ways
        cls.km = coor.cluster_kmeans(data=cls.X, k=100)
        cls.rs = coor.cluster_regspace(data=cls.X, dmin=0.5)
        cls.rt = coor.cluster_uniform_time(data=cls.X, k=100)
        cls.cl = [cls.km, cls.rs, cls.rt]

    @classmethod
    def tearDownClass(cls):
        # bug fix: the temporary directory created in setUpClass was never
        # removed, leaking files on every run.
        import shutil
        shutil.rmtree(cls.dtraj_dir, ignore_errors=True)
        super(TestCluster, cls).tearDownClass()

    def setUp(self):
        pass

    def test_chunksize(self):
        for c in self.cl:
            assert types.is_int(c.chunksize)

    def test_clustercenters(self):
        # one center per cluster, each with the input dimension (2)
        for c in self.cl:
            assert c.clustercenters.shape[0] == c.n_clusters
            assert c.clustercenters.shape[1] == 2

    def test_data_producer(self):
        for c in self.cl:
            assert c.data_producer is not None

    def test_describe(self):
        for c in self.cl:
            desc = c.describe()
            assert types.is_string(desc) or types.is_list_of_string(desc)

    def test_dimension(self):
        # discrete trajectories are one-dimensional
        for c in self.cl:
            assert types.is_int(c.dimension())
            assert c.dimension() == 1

    def test_dtrajs(self):
        for c in self.cl:
            assert len(c.dtrajs) == 1
            assert c.dtrajs[0].dtype == c.output_type()
            assert len(c.dtrajs[0]) == self.T

    def test_get_output(self):
        for c in self.cl:
            O = c.get_output()
            assert types.is_list(O)
            assert len(O) == 1
            assert types.is_int_matrix(O[0])
            assert O[0].shape[0] == self.T
            assert O[0].shape[1] == 1

    def test_in_memory(self):
        for c in self.cl:
            assert isinstance(c.in_memory, bool)

    def test_iterator(self):
        for c in self.cl:
            for itraj, chunk in c:
                assert types.is_int(itraj)
                assert types.is_int_matrix(chunk)
                # chunksize == 0 means "everything in one chunk"
                assert chunk.shape[0] <= c.chunksize or c.chunksize == 0
                assert chunk.shape[1] == c.dimension()

    def test_map(self):
        for c in self.cl:
            Y = c.transform(self.X)
            assert Y.shape[0] == self.T
            assert Y.shape[1] == 1
            # test if consistent with get_output
            assert np.allclose(Y, c.get_output()[0])

    def test_n_frames_total(self):
        for c in self.cl:
            # bug fix: the comparison result was previously discarded
            # (missing assert), so this test could never fail.
            assert c.n_frames_total() == self.T

    def test_number_of_trajectories(self):
        for c in self.cl:
            # bug fix: the comparison result was previously discarded
            # (missing assert), so this test could never fail.
            assert c.number_of_trajectories() == 1

    def test_output_type(self):
        for c in self.cl:
            assert c.output_type() == np.int32

    def test_parametrize(self):
        for c in self.cl:
            # nothing should happen; estimators are already parametrized
            c.parametrize()

    def test_save_dtrajs(self):
        extension = ".dtraj"
        outdir = self.dtraj_dir
        for c in self.cl:
            prefix = "test_save_dtrajs_%s" % type(c).__name__
            c.save_dtrajs(trajfiles=None, prefix=prefix, output_dir=outdir, extension=extension)

            names = ["%s_%i%s" % (prefix, i, extension)
                     for i in range(c.data_producer.number_of_trajectories())]
            names = [os.path.join(outdir, n) for n in names]

            # check files with given patterns are there (os.stat raises if not)
            for f in names:
                os.stat(f)

    def test_trajectory_length(self):
        for c in self.cl:
            assert c.trajectory_length(0) == self.T
            with self.assertRaises(IndexError):
                c.trajectory_length(1)

    def test_trajectory_lengths(self):
        for c in self.cl:
            assert len(c.trajectory_lengths()) == 1
            assert c.trajectory_lengths()[0] == c.trajectory_length(0)
171 |
172 |
173 | if __name__ == "__main__":
174 | unittest.main()
--------------------------------------------------------------------------------
/chainsaw/data/_base/random_accessible.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 |
3 | import numpy as np
4 | import numbers
5 |
6 | import six
7 |
8 |
class NotRandomAccessibleException(Exception):
    """Raised when slice-based random access is requested from a data source
    that has not enabled it (see TrajectoryRandomAccessible)."""
    pass
11 |
12 |
class TrajectoryRandomAccessible(object):
    """Base class exposing four slice-based random access strategies.

    Every strategy starts out as a ``NotImplementedRandomAccessStrategy``
    placeholder; concrete data sources install working strategies and set
    ``self._is_random_accessible = True`` to enable them.
    """

    def __init__(self):
        placeholder = NotImplementedRandomAccessStrategy
        self._ra_cuboid = placeholder(self)
        self._ra_linear_strategy = placeholder(self)
        self._ra_linear_itraj_strategy = placeholder(self)
        self._ra_jagged = placeholder(self)
        self._is_random_accessible = False

    def _require_random_access(self):
        # Common guard shared by every strategy property below.
        if not self._is_random_accessible:
            raise NotRandomAccessibleException()

    @property
    def is_random_accessible(self):
        """
        Check if random access is enabled and all strategies are implemented.
        Returns
        -------
        bool : True if random accessible via all strategies, False otherwise.
        """
        # Check the flag first: accessing the strategy properties would raise
        # NotRandomAccessibleException while the flag is unset.
        if not self._is_random_accessible:
            return False
        strategies = (self.ra_itraj_cuboid, self.ra_linear,
                      self.ra_itraj_jagged, self.ra_itraj_linear)
        return not any(isinstance(s, NotImplementedRandomAccessStrategy)
                       for s in strategies)

    @property
    def ra_itraj_cuboid(self):
        """
        Random access by an up-to-3D slice ``[trajectories, frames, dims]``.

        The frame slice is applied to each selected trajectory, and the
        result is truncated to the shortest selection so a cuboid array comes
        out. Example: with trajectories of lengths 10, 20, 10, the slice
        ``data[:, :15, :3]`` yields a (3, 10, 3) array — three trajectories,
        10 common frames (the extra 5 are truncated), first three dimensions.

        :return: Returns an object that allows access by slices in the described manner.
        """
        self._require_random_access()
        return self._ra_cuboid

    @property
    def ra_itraj_jagged(self):
        """
        Like ra_itraj_cuboid, but trajectories are not truncated to a common
        length; a list of per-trajectory arrays is returned instead.

        :return: Returns an object that allows access by slices in the described manner.
        """
        self._require_random_access()
        return self._ra_jagged

    @property
    def ra_linear(self):
        """
        Random access by an up-to-2D slice ``[frames, dims]`` where frames are
        indexed contiguously across trajectories: the first frame of the
        second trajectory directly follows the last frame of the first.

        :return: Returns an object that allows access by slices in the described manner.
        """
        self._require_random_access()
        return self._ra_linear_strategy

    @property
    def ra_itraj_linear(self):
        """
        Random access with the same argument shape as ra_itraj_cuboid
        (trajs, frames, dims) but with contiguous frame indexing, so it
        returns a plain 2D array.

        :return: A 2D array of the sliced data containing [frames, dims].
        """
        self._require_random_access()
        return self._ra_linear_itraj_strategy
92 |
93 |
class RandomAccessStrategy(six.with_metaclass(ABCMeta)):
    """
    Abstract parent class for all random access strategies. It holds its corresponding data source and
    implements `__getitem__` as well as `__getslice__`, which both get delegated to `_handle_slice`.
    """
    def __init__(self, source, max_slice_dimension=-1):
        # the data source this strategy slices into
        self._source = source
        # maximal number of slice dimensions; -1 presumably means
        # "unspecified/unlimited" — confirm against concrete strategies
        self._max_slice_dimension = max_slice_dimension

    @abstractmethod
    def _handle_slice(self, idx):
        # Subclasses implement the actual slicing semantics here.
        pass

    @property
    def max_slice_dimension(self):
        """
        Property that returns how many dimensions the slice can have.
        Returns
        -------
        int : the maximal slice dimension
        """
        return self._max_slice_dimension

    def __getitem__(self, idx):
        # All item access is funneled through the subclass' _handle_slice.
        return self._handle_slice(idx)

    def __getslice__(self, start, stop):
        """For slices of the form data[1:3]."""
        # Python 2 only; Python 3 routes such slices through __getitem__.
        return self.__getitem__(slice(start, stop))

    def _get_indices(self, item, length):
        # Normalize `item` (slice, list, int, or anything numpy fancy
        # indexing accepts) into an integer index array over range(length).
        if isinstance(item, slice):
            item = np.array(range(*item.indices(length)))
        elif not isinstance(item, np.ndarray):
            if isinstance(item, list):
                item = np.array(item)
            else:
                # e.g. a plain integer: resolve it against the index range
                item = np.arange(0, length)[item]
            # a scalar result is wrapped into a one-element array so callers
            # can always iterate
            if isinstance(item, numbers.Integral):
                item = np.array([item])
        return item

    def _max(self, elems):
        # Treat a scalar int as a one-element collection so max() works.
        if isinstance(elems, numbers.Integral):
            elems = [elems]
        return max(elems)
140 |
141 |
class NotImplementedRandomAccessStrategy(RandomAccessStrategy):
    """Placeholder strategy installed by default; any slicing attempt fails."""
    def _handle_slice(self, idx):
        raise NotImplementedError("Requested random access strategy is not implemented for the current data source.")
145 |
--------------------------------------------------------------------------------
/chainsaw/util/linalg.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see .
18 |
19 |
20 | from __future__ import absolute_import
21 | import numpy as np
22 | import scipy.linalg
23 | import scipy.sparse
24 | import copy
25 | import math
26 | from six.moves import range
27 |
28 | __author__ = 'noe'
29 |
30 |
def mdot(*args):
    """Computes a matrix product of multiple ndarrays

    This is a convenience function to avoid constructs such as np.dot(A, np.dot(B, np.dot(C, D))) and instead
    use mdot(A, B, C, D). The product is evaluated right-to-left, i.e.
    mdot(A, B, C) == np.dot(A, np.dot(B, C)).

    Parameters
    ----------
    *args : an arbitrarily long list of ndarrays that must be compatible for multiplication,
        i.e. args[i].shape[1] = args[i+1].shape[0].
    """
    if len(args) < 1:
        raise ValueError('need at least one argument')
    # fold right-to-left: start with the last factor and prepend the rest
    product = args[-1]
    for factor in reversed(args[:-1]):
        product = np.dot(factor, product)
    return product
50 |
51 |
def submatrix(M, sel):
    """Returns a submatrix of the quadratic matrix M, given by the selected columns and row

    Parameters
    ----------
    M : ndarray(n,n)
        symmetric matrix
    sel : int-array
        selection of rows and columns. Element i,j will be selected if both are in sel.

    Returns
    -------
    S : ndarray(m,m)
        submatrix with m=len(sel)

    """
    assert len(M.shape) == 2, 'M is not a matrix'
    assert M.shape[0] == M.shape[1], 'M is not quadratic'

    if scipy.sparse.issparse(M):
        # CSR is efficient for row slicing, CSC for column slicing;
        # hand back COO like the rest of the code expects.
        rows = M.tocsr()[sel, :]
        return rows.tocsc()[:, sel].tocoo()

    # dense: plain fancy indexing, rows first, then columns
    return M[sel, :][:, sel]
87 |
88 |
89 | def _sort_by_norm(evals, evecs):
90 | """
91 | Sorts the eigenvalues and eigenvectors by descending norm of the eigenvalues
92 |
93 | Parameters
94 | ----------
95 | evals: ndarray(n)
96 | eigenvalues
97 | evecs: ndarray(n,n)
98 | eigenvectors in a column matrix
99 |
100 | Returns
101 | -------
102 | (evals, evecs) : ndarray(m), ndarray(n,m)
103 | the sorted eigenvalues and eigenvectors
104 |
105 | """
106 | # norms
107 | evnorms = np.abs(evals)
108 | # sort
109 | I = np.argsort(evnorms)[::-1]
110 | # permute
111 | evals2 = evals[I]
112 | evecs2 = evecs[:, I]
113 | # done
114 | return (evals2, evecs2)
115 |
116 |
def eig_corr(C0, Ct, epsilon=1e-6):
    r""" Solve generalized eigenvalues problem with correlation matrices C0 and Ct

    Numerically robust solution of a generalized eigenvalue problem of the form

    .. math::
        \mathbf{C}_t \mathbf{r}_i = \mathbf{C}_0 \mathbf{r}_i l_i

    Computes :math:`m` dominant eigenvalues :math:`l_i` and eigenvectors :math:`\mathbf{r}_i`, where
    :math:`m` is the numerical rank of the problem. This is done by first conducting a Schur decomposition
    of the symmetric positive matrix :math:`\mathbf{C}_0`, then truncating its spectrum to retain only eigenvalues
    that are numerically greater than zero, then using this decomposition to define an ordinary eigenvalue
    Problem for :math:`\mathbf{C}_t` of size :math:`m`, and then solving this eigenvalue problem.

    Parameters
    ----------
    C0 : ndarray (n,n)
        time-instantaneous correlation matrix. Must be symmetric positive definite
    Ct : ndarray (n,n)
        time-lagged correlation matrix. Must be symmetric
    epsilon : float
        eigenvalue norm cutoff. Eigenvalues of C0 with norms <= epsilon will be
        cut off. The remaining number of Eigenvalues define the size of
        the output.

    Returns
    -------
    l : ndarray (m)
        The first m generalized eigenvalues, sorted by descending norm
    R : ndarray (n,m)
        The first m generalized eigenvectors, as a column matrix.

    """
    # check input
    assert np.allclose(C0.T, C0), 'C0 is not a symmetric matrix'
    assert np.allclose(Ct.T, Ct), 'Ct is not a symmetric matrix'

    # compute the Eigenvalues of C0 using Schur factorization
    # (for a symmetric matrix the Schur form is diagonal, so the diagonal of
    # S holds the eigenvalues and V the eigenvectors)
    (S, V) = scipy.linalg.schur(C0)
    s = np.diag(S)
    (s, V) = _sort_by_norm(s, V)  # sort them by descending norm

    # determine the cutoff. We know that C0 is an spd matrix,
    # so we select the truncation threshold such that everything that is negative vanishes
    evmin = np.min(s)
    if evmin < 0:
        epsilon = max(epsilon, -evmin + 1e-16)

    # determine effective rank m and perform low-rank approximations.
    # evnorms is sorted descending, so its reverse is ascending and
    # searchsorted yields the count of entries below epsilon.
    evnorms = np.abs(s)
    n = np.shape(evnorms)[0]
    m = n - np.searchsorted(evnorms[::-1], epsilon)
    Vm = V[:, 0:m]
    sm = s[0:m]

    # transform Ct to orthogonal basis given by the eigenvectors of C0
    # (whitening transform T = diag(1/sqrt(s)) V^T)
    Sinvhalf = 1.0 / np.sqrt(sm)
    T = np.dot(np.diag(Sinvhalf), Vm.T)
    Ct_trans = np.dot(np.dot(T, Ct), T.T)

    # solve the symmetric eigenvalue problem in the new basis
    (l, R_trans) = scipy.linalg.eigh(Ct_trans)
    (l, R_trans) = _sort_by_norm(l, R_trans)

    # transform the eigenvectors back to the old basis
    R = np.dot(T.T, R_trans)

    # return result
    return (l, R)
186 |
--------------------------------------------------------------------------------
/chainsaw/_ext/variational_estimators/tests/benchmark_moments.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 | __author__ = 'noe'
4 |
5 | import time
6 | import numpy as np
7 | from .. import moments
8 |
def genS(N):
    """ Generates sparsities given N (number of cols) """
    # candidate numbers of non-zero columns; keep those that fit into N
    candidates = (10, 90, 100, 500, 900, 1000, 2000, 5000, 7500, 9000,
                  10000, 20000, 50000, 75000, 90000)
    return [c for c in candidates if c <= N]
13 |
14 |
def genX(L, N, n_var=None, const=False):
    """Generate an (L, N) random data matrix; if n_var is given, only the
    first n_var columns are random and the rest are constant
    (ones if const, else zeros)."""
    dense = np.random.rand(L, N)  # random data
    if n_var is None:
        return dense
    filler = np.ones((L, N)) if const else np.zeros((L, N))
    filler[:, :n_var] = dense[:, :n_var]
    return filler
25 |
26 |
def genY(L, N, n_var=None, const=False):
    """Generate an (L, N) random data matrix; if n_var is given, only the
    first n_var columns are random and the rest are constant
    (minus ones if const, else zeros)."""
    dense = np.random.rand(L, N)  # random data
    if n_var is None:
        return dense
    filler = -np.ones((L, N)) if const else np.zeros((L, N))
    filler[:, :n_var] = dense[:, :n_var]
    return filler
37 |
38 |
def reftime_momentsXX(X, remove_mean=False, nrep=3):
    """Average wall-clock time per repetition of the naive (plain numpy)
    computation of the second moment matrix X^T X, optionally mean-free."""
    start = time.time()
    for _ in range(nrep):
        col_sum = X.sum(axis=0)  # computation of mean
        if remove_mean:
            # X is rebound here, so later repetitions time the already
            # centered data (same as the original benchmark)
            X = X - col_sum / float(X.shape[0])
        np.dot(X.T, X)  # covariance matrix
    stop = time.time()
    # mean time per repetition
    return (stop - start) / float(nrep)
50 |
51 |
def mytime_momentsXX(X, remove_mean=False, nrep=3):
    """Average wall-clock time per repetition of moments.moments_XX on X."""
    start = time.time()
    for _ in range(nrep):
        w, s, C_XX = moments.moments_XX(X, remove_mean=remove_mean)
    stop = time.time()
    # mean time per repetition
    return (stop - start) / float(nrep)
60 |
61 |
def reftime_momentsXXXY(X, Y, remove_mean=False, symmetrize=False, nrep=3):
    """Average wall-clock time per repetition of the naive computation of the
    moment matrices C_XX and C_XY, optionally symmetrized and/or mean-free."""
    start = time.time()
    for _ in range(nrep):
        sx = X.sum(axis=0)  # computation of mean
        sy = Y.sum(axis=0)  # computation of mean
        if symmetrize:
            # shared symmetric mean
            sx = 0.5 * (sx + sy)
            sy = sx
        if remove_mean:
            # X and Y are rebound, so later repetitions see centered data
            # (same as the original benchmark)
            X = X - sx / float(X.shape[0])
            Y = Y - sy / float(Y.shape[0])
        if symmetrize:
            cxx = np.dot(X.T, X) + np.dot(Y.T, Y)
            cross = np.dot(X.T, Y)
            cxy = cross + cross.T
        else:
            cxx = np.dot(X.T, X)
            cxy = np.dot(X.T, Y)
    stop = time.time()
    # mean time per repetition
    return (stop - start) / float(nrep)
84 |
85 |
def mytime_momentsXXXY(X, Y, remove_mean=False, symmetrize=False, nrep=3):
    """Average wall-clock time per repetition of moments.moments_XXXY."""
    start = time.time()
    for _ in range(nrep):
        w, sx, sy, C_XX, C_XY = moments.moments_XXXY(
            X, Y, remove_mean=remove_mean, symmetrize=symmetrize)
    stop = time.time()
    # mean time per repetition
    return (stop - start) / float(nrep)
94 |
95 |
def benchmark_moments(L=10000, N=10000, nrep=5, xy=False, remove_mean=False, symmetrize=False, const=False):
    """Benchmark moments_XX / moments_XXXY against the trivial numpy
    implementation over a range of sparsities and print a timing table."""
    S = genS(N)

    # time for reference calculation on fully dense data
    X = genX(L, N)
    if xy:
        Y = genY(L, N)
        reftime = reftime_momentsXXXY(X, Y, remove_mean=remove_mean, symmetrize=symmetrize, nrep=nrep)
    else:
        reftime = reftime_momentsXX(X, remove_mean=remove_mean, nrep=nrep)

    # my time for each sparsity level
    times = np.zeros(len(S))
    for k, s in enumerate(S):
        X = genX(L, N, n_var=s, const=const)
        if xy:
            Y = genY(L, N, n_var=s, const=const)
            times[k] = mytime_momentsXXXY(X, Y, remove_mean=remove_mean, symmetrize=symmetrize, nrep=nrep)
        else:
            times[k] = mytime_momentsXX(X, remove_mean=remove_mean, nrep=nrep)

    # assemble report: one row per quantity, one column per sparsity
    rows = ['L, data points', 'N, dimensions', 'S, nonzeros', 'time trivial', 'time moments_XX', 'speed-up']
    table = np.zeros((6, len(S)))
    for idx, row_values in enumerate((L, N, S, reftime, times, reftime / times)):
        table[idx, :] = row_values

    # print table
    fname = 'moments_XXXY' if xy else 'moments_XX'
    print(fname + '\tremove_mean = ' + str(remove_mean) + '\tsym = ' + str(symmetrize) + '\tconst = ' + str(const))
    formats = ['%i', '%i', '%i', '%.3f', '%.3f', '%.3f']
    for label, fmt, row in zip(rows, formats, table):
        print(label + (('\t' + fmt) * table.shape[1]) % tuple(row))
    print()
141 |
142 |
def main():
    """Run the benchmark over several (L, N) problem sizes and all flag
    combinations, preserving the original execution order."""
    LNs = [(100000, 100, 10), (10000, 1000, 7), (1000, 2000, 5), (250, 5000, 5), (100, 10000, 5)]
    for L, N, nrep in LNs:
        # XX-only benchmarks (symmetrize does not apply)
        for rm in (False, True):
            for c in (False, True):
                benchmark_moments(L=L, N=N, nrep=nrep, xy=False,
                                  remove_mean=rm, symmetrize=False, const=c)
        # XX + XY benchmarks over all flag combinations
        for rm in (False, True):
            for sym in (False, True):
                for c in (False, True):
                    benchmark_moments(L=L, N=N, nrep=nrep, xy=True,
                                      remove_mean=rm, symmetrize=sym, const=c)
158 |
159 |
160 | if __name__ == "__main__":
161 | main()
--------------------------------------------------------------------------------
/chainsaw/clustering/regspace.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
17 | # along with this program. If not, see .
18 |
19 |
20 | '''
21 | Created on 26.01.2015
22 |
23 | @author: marscher
24 | '''
25 |
26 | from __future__ import absolute_import
27 |
28 | import warnings
29 |
30 | from . import _regspatial
31 | from .interface import AbstractClustering
32 | from chainsaw.util.annotators import fix_docs
33 | from chainsaw.util.exceptions import NotConvergedWarning
34 |
35 | import numpy as np
36 |
37 |
38 | __all__ = ['RegularSpaceClustering']
39 |
40 |
41 | @fix_docs
42 | class RegularSpaceClustering(AbstractClustering):
43 | r"""Regular space clustering"""
44 |
    def __init__(self, dmin, max_centers=1000, metric='euclidean', stride=1, n_jobs=None, skip=0):
        """Clusters data objects in such a way, that cluster centers are at least in
        distance of dmin to each other according to the given metric.
        The assignment of data objects to cluster centers is performed by
        Voronoi partioning.

        Regular space clustering [Prinz_2011]_ is very similar to Hartigan's leader
        algorithm [Hartigan_1975]_. It consists of two passes through
        the data. Initially, the first data point is added to the list of centers.
        For every subsequent data point, if it has a greater distance than dmin from
        every center, it also becomes a center. In the second pass, a Voronoi
        discretization with the computed centers is used to partition the data.


        Parameters
        ----------
        dmin : float
            minimum distance between all clusters.
        metric : str
            metric to use during clustering ('euclidean', 'minRMSD')
        max_centers : int
            if this cutoff is hit during finding the centers,
            the algorithm will abort.
        stride : int, default 1
            stored as an estimator parameter; presumably the stride used when
            iterating the input data during estimation — confirm against the
            AbstractClustering base class.
        n_jobs : int or None, default None
            Number of threads to use during assignment of the data.
            If None, all available CPUs will be used.
        skip : int, default 0
            stored as an estimator parameter; presumably the number of initial
            frames to skip per trajectory — confirm against the
            AbstractClustering base class.

        References
        ----------

        .. [Prinz_2011] Prinz J-H, Wu H, Sarich M, Keller B, Senne M, Held M, Chodera JD, Schuette Ch and Noe F. 2011.
            Markov models of molecular kinetics: Generation and Validation.
            J. Chem. Phys. 134, 174105.
        .. [Hartigan_1975] Hartigan J. Clustering algorithms.
            New York: Wiley; 1975.

        """
        # metric and n_jobs are handled by the base class
        super(RegularSpaceClustering, self).__init__(metric=metric, n_jobs=n_jobs)

        # record the remaining estimation parameters on the estimator
        self.set_params(dmin=dmin, metric=metric,
                        max_centers=max_centers, stride=stride, skip=skip)
86 |
87 | def describe(self):
88 | return "[RegularSpaceClustering dmin=%f, inp_dim=%i]" % (self._dmin, self.data_producer.dimension())
89 |
90 | @property
91 | def dmin(self):
92 | """Minimum distance between cluster centers."""
93 | return self._dmin
94 |
95 | @dmin.setter
96 | def dmin(self, d):
97 | if d < 0:
98 | raise ValueError("d has to be positive")
99 |
100 | self._dmin = float(d)
101 | self._estimated = False
102 |
103 | @property
104 | def max_centers(self):
105 | """
106 | Cutoff during clustering. If reached no more data is taken into account.
107 | You might then consider a larger value or a larger dmin value.
108 | """
109 | return self._max_centers
110 |
111 | @max_centers.setter
112 | def max_centers(self, value):
113 | if value < 0:
114 | raise ValueError("max_centers has to be positive")
115 |
116 | self._max_centers = int(value)
117 | self._estimated = False
118 |
119 | @property
120 | def n_clusters(self):
121 | return self.max_centers
122 |
123 | @n_clusters.setter
124 | def n_clusters(self, val):
125 | self.max_centers = val
126 |
127 | def _estimate(self, iterable, **kwargs):
128 | ########
129 | # Calculate clustercenters:
130 | # 1. choose first datapoint as centroid
131 | # 2. for all X: calc distances to all clustercenters
132 | # 3. add new centroid, if min(distance to all other clustercenters) >= dmin
133 | ########
134 | # temporary list to store cluster centers
135 | clustercenters = []
136 | used_frames = 0
137 | it = iterable.iterator(return_trajindex=False, stride=self.stride,
138 | chunk=self.chunksize, skip=self.skip)
139 | try:
140 | with it:
141 | for X in it:
142 | used_frames += len(X)
143 | _regspatial.cluster(X.astype(np.float32, order='C', copy=False),
144 | clustercenters, self.dmin,
145 | self.metric, self.max_centers)
146 | except RuntimeError:
147 | msg = 'Maximum number of cluster centers reached.' \
148 | ' Consider increasing max_centers or choose' \
149 | ' a larger minimum distance, dmin.'
150 | self._logger.warning(msg)
151 | warnings.warn(msg)
152 | # finished anyway, because we have no more space for clusters. Rest of trajectory has no effect
153 | clustercenters = np.array(clustercenters)
154 | self.update_model_params(clustercenters=clustercenters,
155 | n_cluster=len(clustercenters))
156 | # pass amount of processed data
157 | used_data = used_frames / float(it.n_frames_total()) * 100.0
158 | raise NotConvergedWarning("Used data for centers: %.2f%%" % used_data)
159 |
160 | clustercenters = np.array(clustercenters)
161 | self.update_model_params(clustercenters=clustercenters,
162 | n_clusters=len(clustercenters))
163 |
164 | if len(clustercenters) == 1:
165 | self._logger.warning('Have found only one center according to '
166 | 'minimum distance requirement of %f' % self.dmin)
167 |
168 | return self
169 |
--------------------------------------------------------------------------------
/chainsaw/tests/test_pca.py:
--------------------------------------------------------------------------------
1 |
2 | # This file is part of PyEMMA.
3 | #
4 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
5 | #
6 | # PyEMMA is free software: you can redistribute it and/or modify
7 | # it under the terms of the GNU Lesser General Public License as published by
8 | # the Free Software Foundation, either version 3 of the License, or
9 | # (at your option) any later version.
10 | #
11 | # This program is distributed in the hope that it will be useful,
12 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 | # GNU General Public License for more details.
15 | #
16 | # You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
18 |
19 |
20 | '''
21 | Created on 02.02.2015
22 |
23 | @author: marscher
24 | '''
25 |
26 | from __future__ import absolute_import
27 | import unittest
28 |
29 | import numpy as np
30 |
31 | from chainsaw import pca
32 | from logging import getLogger
33 | from chainsaw.util import types
34 | from six.moves import range
35 |
36 |
37 | logger = getLogger('chainsaw.'+'TestPCA')
38 |
39 |
class TestPCAExtensive(unittest.TestCase):
    """Exercises the public API of the pca estimator on a 2-state HMM data set."""

    @classmethod
    def setUpClass(cls):
        import msmtools.generation as msmgen

        # set random state, remember old one and set it back in tearDownClass
        cls.old_state = np.random.get_state()
        np.random.seed(0)

        # generate HMM with two Gaussians
        cls.P = np.array([[0.99, 0.01],
                          [0.01, 0.99]])
        cls.T = 10000
        means = [np.array([-1, 1]), np.array([1, -1])]
        widths = [np.array([0.3, 2]), np.array([0.3, 2])]
        # continuous trajectory
        cls.X = np.zeros((cls.T, 2))
        # hidden trajectory
        dtraj = msmgen.generate_traj(cls.P, cls.T)
        for t in range(cls.T):
            s = dtraj[t]
            cls.X[t, 0] = widths[s][0] * np.random.randn() + means[s][0]
            cls.X[t, 1] = widths[s][1] * np.random.randn() + means[s][1]
        cls.pca_obj = pca(data=cls.X, dim=1)

    @classmethod
    def tearDownClass(cls):
        np.random.set_state(cls.old_state)

    def test_chunksize(self):
        assert types.is_int(self.pca_obj.chunksize)

    def test_variances(self):
        obj = pca(data=self.X)
        O = obj.get_output()[0]
        # renamed from 'vars' to avoid shadowing the builtin
        variances = np.var(O, axis=0)
        refs = obj.eigenvalues
        assert np.max(np.abs(variances - refs)) < 0.01

    def test_cumvar(self):
        assert len(self.pca_obj.cumvar) == 2
        # cumulative variance must end at 100%
        assert np.allclose(self.pca_obj.cumvar[-1], 1.0)

    def test_cov(self):
        cov_ref = np.dot(self.X.T, self.X) / float(self.T)
        assert np.all(self.pca_obj.cov.shape == cov_ref.shape)
        assert np.max(self.pca_obj.cov - cov_ref) < 3e-2

    def test_data_producer(self):
        assert self.pca_obj.data_producer is not None

    def test_describe(self):
        desc = self.pca_obj.describe()
        assert types.is_string(desc) or types.is_list_of_string(desc)

    def test_dimension(self):
        assert types.is_int(self.pca_obj.dimension())
        # Here:
        assert self.pca_obj.dimension() == 1
        # Test other variants
        obj = pca(data=self.X, dim=-1, var_cutoff=1.0)
        assert obj.dimension() == 2
        obj = pca(data=self.X, dim=-1, var_cutoff=0.8)
        assert obj.dimension() == 1
        with self.assertRaises(ValueError):  # trying to set both dim and subspace_variance is forbidden
            pca(data=self.X, dim=1, var_cutoff=0.8)

    def test_eigenvalues(self):
        # renamed from 'eval' to avoid shadowing the builtin
        eigvals = self.pca_obj.eigenvalues
        assert len(eigvals) == 2

    def test_eigenvectors(self):
        evec = self.pca_obj.eigenvectors
        assert np.all(evec.shape == (2, 2))

    def test_get_output(self):
        O = self.pca_obj.get_output()
        assert types.is_list(O)
        assert len(O) == 1
        assert types.is_float_matrix(O[0])
        assert O[0].shape[0] == self.T
        assert O[0].shape[1] == self.pca_obj.dimension()

    def test_in_memory(self):
        assert isinstance(self.pca_obj.in_memory, bool)

    def test_iterator(self):
        for itraj, chunk in self.pca_obj:
            assert types.is_int(itraj)
            assert types.is_float_matrix(chunk)
            assert chunk.shape[1] == self.pca_obj.dimension()

    def test_map(self):
        Y = self.pca_obj.transform(self.X)
        assert Y.shape[0] == self.T
        assert Y.shape[1] == 1
        # test if consistent with get_output
        assert np.allclose(Y, self.pca_obj.get_output()[0])

    def test_mean(self):
        mean = self.pca_obj.mean
        assert len(mean) == 2
        # bug fix: the previous assertion was np.max(mean < 0.5), which is
        # True as soon as ANY component is below 0.5 and therefore could
        # never fail. Check the magnitude of the mean instead.
        assert np.max(np.abs(mean)) < 0.5

    def test_n_frames_total(self):
        # map not defined for source
        assert self.pca_obj.n_frames_total() == self.T

    def test_number_of_trajectories(self):
        # map not defined for source
        assert self.pca_obj.number_of_trajectories() == 1

    def test_output_type(self):
        assert self.pca_obj.output_type() == np.float32

    def test_trajectory_length(self):
        assert self.pca_obj.trajectory_length(0) == self.T
        with self.assertRaises(IndexError):
            self.pca_obj.trajectory_length(1)

    def test_trajectory_lengths(self):
        assert len(self.pca_obj.trajectory_lengths()) == 1
        assert self.pca_obj.trajectory_lengths()[0] == self.pca_obj.trajectory_length(0)

    def test_provided_means(self):
        data = np.random.random((300, 3))
        mean = data.mean(axis=0)
        pca_spec_mean = pca(data, mean=mean)
        pca_calc_mean = pca(data)

        np.testing.assert_allclose(mean, pca_calc_mean.mean)
        np.testing.assert_allclose(mean, pca_spec_mean.mean)

        np.testing.assert_allclose(pca_spec_mean.cov, pca_calc_mean.cov)

    def test_partial_fit(self):
        data = [np.random.random((100, 3)), np.random.random((100, 3))]
        pca_part = pca()
        pca_part.partial_fit(data[0])
        pca_part.partial_fit(data[1])

        # incremental estimation must agree with one-shot estimation
        ref = pca(data)
        np.testing.assert_allclose(pca_part.mean, ref.mean)

        np.testing.assert_allclose(pca_part.eigenvalues, ref.eigenvalues)
        np.testing.assert_allclose(pca_part.eigenvectors, ref.eigenvectors)
187 |
# allow running this test module directly as a script
if __name__ == "__main__":
    unittest.main()
190 |
--------------------------------------------------------------------------------
/chainsaw/data/numpy_filereader.py:
--------------------------------------------------------------------------------
1 | # This file is part of PyEMMA.
2 | #
3 | # Copyright (c) 2015, 2014 Computational Molecular Biology Group, Freie Universitaet Berlin (GER)
4 | #
5 | # PyEMMA is free software: you can redistribute it and/or modify
6 | # it under the terms of the GNU Lesser General Public License as published by
7 | # the Free Software Foundation, either version 3 of the License, or
8 | # (at your option) any later version.
9 | #
10 | # This program is distributed in the hope that it will be useful,
11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 | # GNU General Public License for more details.
14 | #
15 | # You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 | '''
18 | Created on 07.04.2015
19 |
20 | @author: marscher
21 | '''
22 |
23 | from __future__ import absolute_import
24 |
25 | import functools
26 |
27 | import numpy as np
28 |
29 | from chainsaw.data._base.datasource import DataSourceIterator, DataSource
30 | from chainsaw.data.util.traj_info_cache import TrajInfo
31 | from chainsaw.util.annotators import fix_docs
32 |
33 | from chainsaw.data.util.fileformat_registry import FileFormatRegistry
34 |
35 |
@fix_docs
@FileFormatRegistry.register('.npy')
class NumPyFileReader(DataSource):

    """reads NumPy files in chunks. Supports .npy files

    Parameters
    ----------
    filenames : str or list of strings

    chunksize : int
        how many rows are read at once

    mmap_mode : str (optional), default='r'
        binary NumPy arrays are being memory mapped using this flag.
    """

    def __init__(self, filenames, chunksize=1000, mmap_mode='r', **kw):
        super(NumPyFileReader, self).__init__(chunksize=chunksize)
        self._is_reader = True

        # accept a single file name as well as a list/tuple of names
        if not isinstance(filenames, (list, tuple)):
            filenames = [filenames]

        self.mmap_mode = mmap_mode
        self.filenames = filenames

    def _create_iterator(self, skip=0, chunk=0, stride=1, return_trajindex=False, cols=None):
        """Create a fresh iterator over the registered files."""
        return NPYIterator(self, skip=skip, chunk=chunk, stride=stride,
                           return_trajindex=return_trajindex, cols=cols)

    def describe(self):
        """Return a short human-readable description of this reader."""
        shapes = [(l, self.ndim) for l
                  in self._lengths]
        return "[NumpyFileReader arrays with shapes: %s]" % shapes

    def _reshape(self, array):
        """Return *array* with shape (n_frames, n_features).

        A 1d array becomes a column vector; arrays with more than two
        dimensions keep their first (frame) axis, while all remaining axes
        are collapsed into a single feature axis.
        """
        if array.ndim == 1:
            array = np.atleast_2d(array).T
        elif array.ndim > 2:
            # hold first dimension, flatten the rest (idiomatic -1 instead of
            # an explicit functools.reduce product of the trailing axes)
            array = array.reshape(array.shape[0], -1)
        return array

    def _load_file(self, itraj):
        """Load (and by default memory-map) the itraj-th file as a 2d array."""
        filename = self._filenames[itraj]
        x = np.load(filename, mmap_mode=self.mmap_mode)
        return self._reshape(x)

    def _get_traj_info(self, filename):
        """Determine trajectory length and dimension for the given file name."""
        idx = self.filenames.index(filename)
        array = self._load_file(idx)
        length, ndim = np.shape(array)

        return TrajInfo(ndim, length)
104 |
105 |
class NPYIterator(DataSourceIterator):
    """Iterator yielding chunks from the arrays held by a NumPyFileReader."""

    def __init__(self, data_source, skip=0, chunk=0, stride=1, return_trajindex=False, cols=False):
        super(NPYIterator, self).__init__(data_source=data_source, skip=skip,
                                          chunk=chunk, stride=stride,
                                          return_trajindex=return_trajindex,
                                          cols=cols)

        # index of the trajectory currently loaded into self._array;
        # -1 means no array has been loaded yet.
        self._last_itraj = -1

    def reset(self):
        DataSourceIterator.reset(self)
        self._last_itraj = -1

    def close(self):
        self._close_filehandle()

    def _close_filehandle(self):
        # drop the (possibly memory-mapped) array to release the file handle
        if not hasattr(self, '_array') or self._array is None:
            return
        del self._array
        self._array = None

    def _open_filehandle(self):
        self._array = self._data_source._load_file(self._itraj)
        # bug fix: record which trajectory is loaded. Previously this was
        # never updated, so the cache check in _next_chunk always failed and
        # the file was re-opened for every single chunk.
        self._last_itraj = self._itraj

    def _next_chunk(self):
        """Return the next chunk of data, advancing time/trajectory indices.

        Raises
        ------
        StopIteration
            when all trajectories have been consumed.
        """
        if self._itraj >= self._data_source.ntraj:
            self.close()
            raise StopIteration()

        # lazily (re-)open the file whenever the trajectory index has moved on
        if self._itraj != self._last_itraj:
            self._close_filehandle()
            self._open_filehandle()

        traj_len = len(self._array)
        traj = self._array

        # skip only if complete trajectory mode or first chunk
        skip = self.skip if self.chunksize == 0 or self._t == 0 else 0

        # if stride by dict, update traj length accordingly
        if not self.uniform_stride:
            traj_len = self.ra_trajectory_length(self._itraj)

        # complete trajectory mode
        if self.chunksize == 0:
            if not self.uniform_stride:
                X = traj[self.ra_indices_for_traj(self._itraj)]
                self._itraj += 1

                # skip the trajs that are not in the stride dict
                while self._itraj < self.number_of_trajectories() \
                        and (self._itraj not in self.traj_keys):
                    self._itraj += 1

            else:
                X = traj[skip::self.stride]
                self._itraj += 1

            return X

        # chunked mode
        else:
            if not self.uniform_stride:
                X = traj[self.ra_indices_for_traj(self._itraj)[self._t:min(self._t + self.chunksize, traj_len)]]
                upper_bound = min(self._t + self.chunksize, traj_len)
            else:
                upper_bound = min(skip + self._t + self.chunksize * self.stride, traj_len)
                slice_x = slice(skip + self._t, upper_bound, self.stride)
                X = traj[slice_x]

            # set new time position
            self._t = upper_bound

            if self._t >= traj_len:
                # current trajectory exhausted: advance to the next one
                self._itraj += 1
                self._t = 0

                # if we have a dictionary, skip trajectories that are not in the key set
                while not self.uniform_stride and self._itraj < self.number_of_trajectories() \
                        and (self._itraj not in self.traj_keys):
                    self._itraj += 1

                # note: the next file (if any) is opened lazily at the top of
                # the next _next_chunk call; the previous eager preload here
                # bypassed the _last_itraj bookkeeping and caused redundant
                # reloads.

            return X
195 |
--------------------------------------------------------------------------------