├── .dockerignore ├── .github └── workflows │ └── build-docs.yaml ├── .gitignore ├── .readthedocs.yaml ├── Dockerfile ├── LICENSE ├── README.md ├── conftest.py ├── dev └── requirements.txt ├── docs ├── Makefile ├── make.bat └── source │ ├── cli.rst │ ├── cli.treeflow_ml.rst │ ├── cli.treeflow_vi.rst │ ├── conf.py │ ├── index.rst │ ├── installation.md │ ├── model-definition.md │ ├── modules.rst │ ├── rates-and-dates.nblink │ ├── treeflow.acceleration.bito.beagle.rst │ ├── treeflow.acceleration.bito.instance.rst │ ├── treeflow.acceleration.bito.ratio_transform.rst │ ├── treeflow.acceleration.bito.rst │ ├── treeflow.acceleration.rst │ ├── treeflow.bijectors.fixed_topology_bijector.rst │ ├── treeflow.bijectors.highway_flow.rst │ ├── treeflow.bijectors.highway_flow_node_bijector.rst │ ├── treeflow.bijectors.node_height_ratio_bijector.rst │ ├── treeflow.bijectors.preorder_node_bijector.rst │ ├── treeflow.bijectors.rst │ ├── treeflow.bijectors.tree_ratio_bijector.rst │ ├── treeflow.cli.benchmark.rst │ ├── treeflow.cli.inference_common.rst │ ├── treeflow.cli.ml.rst │ ├── treeflow.cli.rst │ ├── treeflow.cli.vi.rst │ ├── treeflow.debug.minimize_eager.rst │ ├── treeflow.debug.nonfinite_convergence_criterion.rst │ ├── treeflow.debug.rst │ ├── treeflow.distributions.discrete.rst │ ├── treeflow.distributions.discrete_parameter_mixture.rst │ ├── treeflow.distributions.discretized.rst │ ├── treeflow.distributions.leaf_ctmc.rst │ ├── treeflow.distributions.markov_chain.linear_gaussian.rst │ ├── treeflow.distributions.markov_chain.postorder.rst │ ├── treeflow.distributions.markov_chain.rst │ ├── treeflow.distributions.rst │ ├── treeflow.distributions.sample_weighted.rst │ ├── treeflow.distributions.tree.base_tree_distribution.rst │ ├── treeflow.distributions.tree.birthdeath.birth_death_contemporary_sampling.rst │ ├── treeflow.distributions.tree.birthdeath.rst │ ├── treeflow.distributions.tree.birthdeath.yule.rst │ ├── treeflow.distributions.tree.coalescent.constant_coalescent.rst │ 
├── treeflow.distributions.tree.coalescent.rst │ ├── treeflow.distributions.tree.rooted_tree_distribution.rst │ ├── treeflow.distributions.tree.rst │ ├── treeflow.evolution.calibration.calibration.rst │ ├── treeflow.evolution.calibration.mrca.rst │ ├── treeflow.evolution.calibration.rst │ ├── treeflow.evolution.rst │ ├── treeflow.evolution.seqio.rst │ ├── treeflow.evolution.substitution.base_substitution_model.rst │ ├── treeflow.evolution.substitution.eigendecomposition.rst │ ├── treeflow.evolution.substitution.nucleotide.alphabet.rst │ ├── treeflow.evolution.substitution.nucleotide.gtr.rst │ ├── treeflow.evolution.substitution.nucleotide.hky.rst │ ├── treeflow.evolution.substitution.nucleotide.jc.rst │ ├── treeflow.evolution.substitution.nucleotide.rst │ ├── treeflow.evolution.substitution.probabilities.rst │ ├── treeflow.evolution.substitution.rst │ ├── treeflow.evolution.substitution.util.rst │ ├── treeflow.model.approximation.cascading_flows.rst │ ├── treeflow.model.approximation.iaf.rst │ ├── treeflow.model.approximation.mean_field.rst │ ├── treeflow.model.approximation.rst │ ├── treeflow.model.event_shape_bijector.rst │ ├── treeflow.model.io.rst │ ├── treeflow.model.ml.rst │ ├── treeflow.model.phylo_model.rst │ ├── treeflow.model.rst │ ├── treeflow.model.structured_approximation.rst │ ├── treeflow.rst │ ├── treeflow.tf_util.attrs.rst │ ├── treeflow.tf_util.dtype_util.rst │ ├── treeflow.tf_util.linear_operator_upper_triangular.rst │ ├── treeflow.tf_util.rst │ ├── treeflow.tf_util.vectorize.rst │ ├── treeflow.traversal.anchor_heights.rst │ ├── treeflow.traversal.phylo_likelihood.rst │ ├── treeflow.traversal.postorder.rst │ ├── treeflow.traversal.preorder.rst │ ├── treeflow.traversal.ratio_transform.rst │ ├── treeflow.traversal.rst │ ├── treeflow.tree.base_tree.rst │ ├── treeflow.tree.io.rst │ ├── treeflow.tree.rooted.base_rooted_tree.rst │ ├── treeflow.tree.rooted.numpy_rooted_tree.rst │ ├── treeflow.tree.rooted.rst │ ├── 
treeflow.tree.rooted.tensorflow_rooted_tree.rst │ ├── treeflow.tree.rst │ ├── treeflow.tree.taxon_set.rst │ ├── treeflow.tree.topology.base_tree_topology.rst │ ├── treeflow.tree.topology.numpy_topology_operations.rst │ ├── treeflow.tree.topology.numpy_tree_topology.rst │ ├── treeflow.tree.topology.rst │ ├── treeflow.tree.topology.tensorflow_tree_topology.rst │ ├── treeflow.tree.unrooted.base_unrooted_tree.rst │ ├── treeflow.tree.unrooted.numpy_unrooted_tree.rst │ ├── treeflow.tree.unrooted.rst │ ├── treeflow.tree.unrooted.tensorflow_unrooted_tree.rst │ ├── treeflow.vi.convergence_criteria.nonfinite.rst │ ├── treeflow.vi.convergence_criteria.rst │ ├── treeflow.vi.fixed_topology_advi.rst │ ├── treeflow.vi.marginal_likelihood.rst │ ├── treeflow.vi.optimizers.robust_optimizer.rst │ ├── treeflow.vi.optimizers.rst │ ├── treeflow.vi.progress_bar.rst │ ├── treeflow.vi.rst │ ├── treeflow.vi.util.rst │ └── tutorials.rst ├── examples ├── README.md ├── carnivores.ipynb ├── demo-data │ ├── YFV.fasta │ ├── YFV.newick │ ├── YFV.nex │ ├── carnivores.fasta │ ├── carnivores.newick │ ├── h3n2.fasta │ └── h3n2.nwk ├── h3n2-model.yaml ├── h3n2-vi.sh ├── rates-and-dates-model.yaml └── rates-and-dates.ipynb ├── pyproject.toml ├── pytest.ini ├── requirements.txt ├── setup.cfg ├── setup.py ├── test ├── acceleration │ └── bito │ │ ├── test_beagle.py │ │ ├── test_bito_instance.py │ │ └── test_bito_ratio_transform.py ├── bijectors │ ├── test_highway_flow_node_bijector.py │ ├── test_node_height_ratio_bijector.py │ ├── test_preorder_node_bijector.py │ └── test_tree_ratio_bijector.py ├── cli │ ├── test_benchmark.py │ ├── test_ml_cli.py │ └── test_vi.py ├── data │ ├── beast-test-case.fasta │ ├── beast-test-case.nwk │ ├── hello.fasta │ ├── hello.nwk │ ├── model.yaml │ ├── test-beast-analysis.xml │ ├── tree-sim.newick │ ├── wnv.fasta │ ├── wnv.nwk │ └── yule-model.yaml ├── debug_util │ ├── test_minimize_eager.py │ └── test_nonfinite_convergence_criterion.py ├── distributions │ ├── markov_chain │ │ 
├── test_linear_gaussian.py │ │ └── test_postorder_node_markov_chain.py │ ├── test_discrete_parameter_mixture.py │ ├── test_discretized.py │ ├── test_leaf_ctmc.py │ ├── test_sample_weighted.py │ └── tree │ │ ├── birthdeath │ │ ├── test_birth_death_contemporary_sampling.py │ │ └── test_yule.py │ │ ├── coalescent │ │ └── test_constant_coalescent.py │ │ └── test_rooted_tree_distribution.py ├── evolution │ ├── calibration │ │ ├── test_calibration.py │ │ └── test_mrca.py │ ├── substitution │ │ ├── nucleotide │ │ │ ├── test_gtr.py │ │ │ ├── test_hky.py │ │ │ └── test_jc.py │ │ ├── test_eigendecomposition.py │ │ └── test_probabilities.py │ └── test_seqio.py ├── fixtures │ ├── __init__.py │ ├── cli_fixtures.py │ ├── data_fixtures.py │ ├── ratio_fixtures.py │ ├── substitution_fixtures.py │ └── tree_fixtures.py ├── helpers │ └── treeflow_test_helpers │ │ ├── __init__.py │ │ ├── optimization_helpers.py │ │ ├── ratio_helpers.py │ │ ├── substitution_helpers.py │ │ └── tree_helpers.py ├── model │ ├── approximation │ │ └── test_cascading_flows.py │ ├── test_approximation.py │ ├── test_event_space_bijector.py │ ├── test_ml.py │ ├── test_model_io.py │ └── test_phylo_model.py ├── tf_util │ └── test_vectorize.py ├── traversal │ ├── test_get_anchor_heights.py │ ├── test_phylo_likelihood.py │ ├── test_postorder.py │ ├── test_preorder.py │ └── test_ratio_transform.py ├── tree │ ├── rooted │ │ ├── test_numpy_rooted_tree.py │ │ └── test_tensorflow_rooted_tree.py │ ├── test_io.py │ ├── test_taxon_set.py │ └── topology │ │ ├── test_numpy_topology_operations.py │ │ └── test_tensorflow_topology.py └── vi │ ├── optimizers │ └── test_robust_optimizer.py │ ├── test_fixed_topology_advi.py │ ├── test_marginal_likelihood.py │ └── test_progress_bar.py └── treeflow ├── __init__.py ├── acceleration ├── __init__.py └── bito │ ├── __init__.py │ ├── beagle.py │ ├── instance.py │ └── ratio_transform.py ├── bijectors ├── __init__.py ├── fixed_topology_bijector.py ├── highway_flow.py ├── 
highway_flow_node_bijector.py ├── node_height_ratio_bijector.py ├── preorder_node_bijector.py └── tree_ratio_bijector.py ├── cli ├── __init__.py ├── benchmark.py ├── inference_common.py ├── ml.py └── vi.py ├── debug ├── __init__.py ├── minimize_eager.py └── nonfinite_convergence_criterion.py ├── distributions ├── __init__.py ├── discrete.py ├── discrete_parameter_mixture.py ├── discretized.py ├── leaf_ctmc.py ├── markov_chain │ ├── __init__.py │ ├── linear_gaussian.py │ └── postorder.py ├── sample_weighted.py └── tree │ ├── __init__.py │ ├── base_tree_distribution.py │ ├── birthdeath │ ├── __init__.py │ ├── birth_death_contemporary_sampling.py │ └── yule.py │ ├── coalescent │ ├── __init__.py │ └── constant_coalescent.py │ └── rooted_tree_distribution.py ├── evolution ├── __init__.py ├── calibration │ ├── __init__.py │ ├── calibration.py │ └── mrca.py ├── seqio.py └── substitution │ ├── __init__.py │ ├── base_substitution_model.py │ ├── eigendecomposition.py │ ├── nucleotide │ ├── __init__.py │ ├── alphabet.py │ ├── gtr.py │ ├── hky.py │ └── jc.py │ ├── probabilities.py │ └── util.py ├── model ├── __init__.py ├── approximation │ ├── __init__.py │ ├── cascading_flows.py │ ├── iaf.py │ └── mean_field.py ├── event_shape_bijector.py ├── io.py ├── ml.py ├── phylo_model.py └── structured_approximation.py ├── tf_util ├── __init__.py ├── attrs.py ├── dtype_util.py ├── linear_operator_upper_triangular.py └── vectorize.py ├── traversal ├── __init__.py ├── anchor_heights.py ├── phylo_likelihood.py ├── postorder.py ├── preorder.py └── ratio_transform.py ├── tree ├── __init__.py ├── base_tree.py ├── io.py ├── rooted │ ├── __init__.py │ ├── base_rooted_tree.py │ ├── numpy_rooted_tree.py │ └── tensorflow_rooted_tree.py ├── taxon_set.py ├── topology │ ├── __init__.py │ ├── base_tree_topology.py │ ├── numpy_topology_operations.py │ ├── numpy_tree_topology.py │ └── tensorflow_tree_topology.py └── unrooted │ ├── __init__.py │ ├── base_unrooted_tree.py │ ├── numpy_unrooted_tree.py │ 
└── tensorflow_unrooted_tree.py └── vi ├── __init__.py ├── convergence_criteria ├── __init__.py └── nonfinite.py ├── fixed_topology_advi.py ├── marginal_likelihood.py ├── optimizers ├── __init__.py └── robust_optimizer.py ├── progress_bar.py └── util.py /.dockerignore: -------------------------------------------------------------------------------- 1 | # Include any files or directories that you don't want to be copied to your 2 | # container here (e.g., local build artifacts, temporary files, etc.). 3 | # 4 | # For more help, visit the .dockerignore file reference guide at 5 | # https://docs.docker.com/go/build-context-dockerignore/ 6 | 7 | **/.DS_Store 8 | **/__pycache__ 9 | **/.venv 10 | **/.classpath 11 | **/.dockerignore 12 | **/.env 13 | **/.git 14 | **/.gitignore 15 | **/.project 16 | **/.settings 17 | **/.toolstarget 18 | **/.vs 19 | **/.vscode 20 | **/*.*proj.user 21 | **/*.dbmdl 22 | **/*.jfm 23 | **/bin 24 | **/charts 25 | **/docker-compose* 26 | **/compose* 27 | **/Dockerfile* 28 | **/node_modules 29 | **/npm-debug.log 30 | **/obj 31 | **/secrets.dev.yaml 32 | **/values.dev.yaml 33 | LICENSE 34 | README.md 35 | -------------------------------------------------------------------------------- /.github/workflows/build-docs.yaml: -------------------------------------------------------------------------------- 1 | name: Build docs 2 | on: pull_request 3 | jobs: 4 | docs: 5 | runs-on: ubuntu-latest 6 | steps: 7 | - uses: actions/checkout@v3 8 | - uses: actions/setup-python@v4 9 | with: 10 | python-version: 3.8 11 | - name: Install dependencies 12 | run: | 13 | pip install -r dev/requirements.txt 14 | - name: Install package 15 | run: | 16 | pip install . 
17 | - name: Install pandoc 18 | run: | 19 | sudo apt-get install pandoc 20 | - name: Sphinx build 21 | run: | 22 | cd docs 23 | make html -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | 
.mypy_cache/ 105 | 106 | 107 | .vscode 108 | /scratch 109 | 110 | examples/demo-out/ -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.8" 13 | 14 | # Build documentation in the docs/ directory with Sphinx 15 | sphinx: 16 | configuration: docs/source/conf.py 17 | 18 | # Optionally declare the Python requirements required to build your docs 19 | python: 20 | install: 21 | - requirements: dev/requirements.txt 22 | - method: pip 23 | path: . -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # syntax=docker/dockerfile:1 2 | 3 | # Comments are provided throughout this file to help you get started. 4 | # If you need more help, visit the Dockerfile reference guide at 5 | # https://docs.docker.com/go/dockerfile-reference/ 6 | 7 | # Want to help us make this template better? Share your feedback here: https://forms.gle/ybq9Krt8jtBL3iCk7 8 | 9 | ARG PYTHON_VERSION=3.9.12 10 | FROM python:${PYTHON_VERSION}-slim as base 11 | 12 | # Prevents Python from writing pyc files. 13 | ENV PYTHONDONTWRITEBYTECODE=1 14 | 15 | # Keeps Python from buffering stdout and stderr to avoid situations where 16 | # the application crashes without emitting any logs due to buffering. 17 | ENV PYTHONUNBUFFERED=1 18 | 19 | WORKDIR /app 20 | 21 | # Create a non-privileged user that the app will run under. 
22 | # See https://docs.docker.com/go/dockerfile-user-best-practices/ 23 | ARG UID=10001 24 | RUN adduser \ 25 | --disabled-password \ 26 | --gecos "" \ 27 | # --home "/nonexistent" \ 28 | --shell "/sbin/nologin" \ 29 | # --no-create-home \ 30 | --uid "${UID}" \ 31 | appuser 32 | 33 | # Download dependencies as a separate step to take advantage of Docker's caching. 34 | # Leverage a cache mount to /root/.cache/pip to speed up subsequent builds. 35 | # Leverage a bind mount to requirements.txt to avoid having to copy it 36 | # into this layer. 37 | RUN --mount=type=cache,target=/root/.cache/pip \ 38 | --mount=type=bind,source=requirements.txt,target=requirements.txt \ 39 | python -m pip install -r requirements.txt 40 | 41 | COPY . . 42 | RUN python -m pip install . 43 | 44 | # Switch to the non-privileged user to run the application. 45 | USER appuser 46 | 47 | # The source code was already copied into the container above (COPY . .). 48 | 49 | # Expose the port that the application listens on. 50 | EXPOSE 8888 51 | 52 | 53 | # Run the application. 54 | CMD jupyter lab --ip="0.0.0.0" --no-browser 55 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TreeFlow 2 | 3 | TreeFlow is a library for phylogenetic modelling and inference based on [TensorFlow Probability](https://www.tensorflow.org/probability) (TFP). 4 | 5 | It also includes [command line interfaces](https://treeflow.readthedocs.io/en/latest/cli.html) for fixed-topology phylogenetic inference.
6 | 7 | ## Documentation 8 | 9 | [Online manual: tutorials, API documentation, CLI description](https://treeflow.readthedocs.io/en/latest/) 10 | 11 | ## Installation and getting started 12 | 13 | See [installation instructions](https://treeflow.readthedocs.io/en/latest/installation.html) 14 | * (Optional) Build and install [`bito`](https://github.com/phylovi/bito) for accelerated computations - not used in CLI 15 | 16 | ## Citation 17 | 18 | If you want to cite or read about TreeFlow, please see the paper: 19 | 20 | Christiaan Swanepoel, Mathieu Fourment, Xiang Ji, Hassan Nasif, Marc A Suchard, Frederick A Matsen IV, Alexei Drummond. ["TreeFlow: probabilistic programming and automatic differentiation for phylogenetics"](https://arxiv.org/abs/2211.05220). arXiv preprint arXiv:2211.05220 (2022). 21 | 22 | ## Unit tests 23 | 24 | 1. `pip install -r dev/requirements.txt` 25 | 2. `pytest` 26 | 27 | Note tests for acceleration and the benchmark CLI will fail if the extra dependencies for those components are not installed (and `bito` cannot yet be installed with `pip`) -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pathlib 4 | import pytest 5 | import tensorflow as tf 6 | from treeflow import DEFAULT_FLOAT_DTYPE_TF 7 | from functools import partial 8 | 9 | 10 | conftest_dir = pathlib.Path(__file__).parents[0] 11 | sys.path.append(str(conftest_dir / "test" / "helpers")) 12 | sys.path.append(str(conftest_dir / "test" / "fixtures")) 13 | 14 | 15 | @pytest.fixture 16 | def tensor_constant(): 17 | return partial(tf.constant, dtype=DEFAULT_FLOAT_DTYPE_TF) 18 | 19 | 20 | pytest_plugins = [ 21 | "tree_fixtures", 22 | "data_fixtures", 23 | "ratio_fixtures", 24 | "substitution_fixtures", 25 | "cli_fixtures", 26 | ] 27 | 28 | if os.getenv("_PYTEST_RAISE", "0") != "0": 29 | # Stop pytest catching exceptions in debug 
run configuration 30 | @pytest.hookimpl(tryfirst=True) 31 | def pytest_exception_interact(call): 32 | raise call.excinfo.value 33 | 34 | @pytest.hookimpl(tryfirst=True) 35 | def pytest_internalerror(excinfo): 36 | raise excinfo.value 37 | 38 | 39 | @pytest.fixture 40 | def test_data_dir(): 41 | return pathlib.Path("test") / "data" 42 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | 22 | buildapi: 23 | sphinx-apidoc -fMeT -d 1 ../treeflow -o source 24 | @echo "Auto-generation of API documentation finished. " \ 25 | "The generated files are in 'source/'" -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. 
Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/cli.rst: -------------------------------------------------------------------------------- 1 | Command line interfaces 2 | ================================ 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | 7 | cli.treeflow_vi 8 | cli.treeflow_ml 9 | -------------------------------------------------------------------------------- /docs/source/cli.treeflow_ml.rst: -------------------------------------------------------------------------------- 1 | Maximum likelihood CLI 2 | ================================== 3 | 4 | .. click:: treeflow.cli.ml:treeflow_ml 5 | :prog: treeflow_ml 6 | :nested: full -------------------------------------------------------------------------------- /docs/source/cli.treeflow_vi.rst: -------------------------------------------------------------------------------- 1 | Variational Inference CLI 2 | ================================== 3 | 4 | .. click:: treeflow.cli.vi:treeflow_vi 5 | :prog: treeflow_vi 6 | :nested: full -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 
2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | 9 | from silence_tensorflow import silence_tensorflow 10 | 11 | silence_tensorflow() 12 | 13 | project = "TreeFlow" 14 | copyright = "2023, Christiaan Swanepoel" 15 | author = "Christiaan Swanepoel" 16 | release = "0.0.1" 17 | 18 | # -- General configuration --------------------------------------------------- 19 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 20 | 21 | extensions = [ 22 | "sphinx.ext.autodoc", 23 | "sphinx.ext.coverage", 24 | "sphinx_rtd_theme", 25 | "sphinxcontrib.napoleon", 26 | "myst_parser", 27 | "sphinx_click", 28 | "nbsphinx", 29 | "nbsphinx_link", 30 | ] 31 | 32 | templates_path = ["_templates"] 33 | exclude_patterns = [] 34 | autodoc_default_options = { 35 | "members": True, 36 | "inherited-members": True, 37 | "undoc-members": True, 38 | } 39 | autodoc_inherit_docstrings = True 40 | autodoc_member_order = "bysource" 41 | 42 | # -- Options for HTML output ------------------------------------------------- 43 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 44 | 45 | html_theme = "sphinx_rtd_theme" 46 | html_static_path = ["_static"] 47 | 48 | nbsphinx_prolog = """ 49 | .. raw:: html 50 | 51 | 56 | """ 57 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. TreeFlow documentation master file 2 | 3 | TreeFlow: automatic differentiation and probabilistic modelling with phylogenetic trees 4 | ======================================================================================= 5 | 6 | .. 
toctree:: 7 | :maxdepth: 2 8 | :caption: Contents: 9 | 10 | installation 11 | tutorials 12 | cli 13 | model-definition 14 | modules 15 | 16 | 17 | 18 | Indices and tables 19 | ================== 20 | 21 | * :ref:`genindex` 22 | * :ref:`modindex` 23 | * :ref:`search` 24 | -------------------------------------------------------------------------------- /docs/source/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | 1. Set up a Python environment (e.g. with `virtualenv` or `conda`) with Python 3.7 or later 4 | 2. Clone and navigate to the repository: `git clone https://github.com/christiaanjs/treeflow.git` then `cd treeflow` 5 | 3. Install the Python package into the active environment: `pip install .` 6 | 4. Run TreeFlow 7 | * To run Jupyter Lab 8 | 1. `jupyter lab` 9 | 2. Follow the link that appears to Jupyter (`http://127.0.0.1:8888/lab?token=...`) 10 | * Run one of [TreeFlow's command line applications](cli) 11 | 12 | 13 | ## Docker 14 | 15 | 1. Install Docker following the [official instructions](https://docs.docker.com/engine/install/) 16 | 2. Clone and navigate to the repository: `git clone https://github.com/christiaanjs/treeflow.git` then `cd treeflow` 17 | 3. Build the image: `docker build -t treeflow .` 18 | 4. Run TreeFlow 19 | * To run Jupyter Lab: 20 | 1. `docker run -p 8888:8888 treeflow` (options such as `-p` must come before the image name) 21 | * To use a different port (e.g. 8999) use `docker run -p 8999:8888 treeflow` 22 | * If you need to access a local directory (e.g.
for input and output files) mount it into the container: `docker run -p 8888:8888 -v /home/dev/repo/data:/app/data treeflow` to mount the directory `/home/dev/repo/data` to the `data` directory in the notebook (both must be absolute paths, and `/app` is the working directory in the container) 23 | * If you need to save output to the mounted directory you'll need to give the container user (UID 10001) write permission: `mkdir /home/dev/repo/data/out` then `sudo chown -R 10001:10001 /home/dev/repo/data/out` 24 | 25 | 2. Follow the link that appears to Jupyter (`http://127.0.0.1:8888/lab?token=...`), changing the port if necessary 26 | 3. To stop the process, use `docker kill {container}` (you can use `docker ps` to look up the ID) 27 | * On Linux `docker ps | grep treeflow | awk '{print $1}' | xargs docker kill` will stop all TreeFlow containers 28 | * To run one of [TreeFlow's command line applications](cli): 29 | * `docker run treeflow {command}` 30 | * You may want to mount a data directory for input/output e.g. `docker run -v /home/dev/repo/data:/app/data treeflow treeflow_vi -i data/alignment.fasta -t data/topology.nwk --tree-samples-output data/tree-results.nexus` 31 | 32 | -------------------------------------------------------------------------------- /docs/source/model-definition.md: -------------------------------------------------------------------------------- 1 | # Model definition format 2 | 3 | For an example model definition file, see [`examples/h3n2-model.yaml`](https://github.com/christiaanjs/treeflow/blob/master/examples/h3n2-model.yaml). 4 | 5 | TreeFlow's command line interfaces use a YAML model definition format. Each model definition file has four sections: 6 | 7 | ```yaml 8 | tree: 9 | ... 10 | clock: 11 | ... 12 | site: 13 | ... 14 | substitution: 15 | ... 16 | ``` 17 | 18 | Each section is a YAML mapping from the name of the option chosen for that model component to its parameters.
19 | 20 | Each parameter can be a fixed value, or a prior distribution if the parameter is to be estimated. 21 | 22 | Prior distributions are specified as mappings from a distribution name to parameters. 23 | 24 | For example: 25 | 26 | ```yaml 27 | substitution: 28 | hky: 29 | kappa: 30 | 2.0 31 | frequencies: 32 | dirichlet: 33 | concentration: [2.0, 2.0, 2.0, 2.0] 34 | 35 | ``` 36 | 37 | Documentation of all the options is still to come; for now, see the source code at [`treeflow.model.phylo_model.phylo_model_to_joint_distribution`](https://github.com/christiaanjs/treeflow/blob/master/treeflow/model/phylo_model.py) for reference. 38 | 39 | ## Prior distributions 40 | 41 | For the parameters of each distribution, see the corresponding TensorFlow Probability distribution. 42 | 43 | * `normal` 44 | * `lognormal` 45 | * `gamma` 46 | * `exponential` 47 | * `beta` 48 | * `dirichlet` -------------------------------------------------------------------------------- /docs/source/modules.rst: -------------------------------------------------------------------------------- 1 | API documentation 2 | ========================== 3 | 4 | .. toctree:: 5 | :maxdepth: 3 6 | 7 | treeflow 8 | -------------------------------------------------------------------------------- /docs/source/rates-and-dates.nblink: -------------------------------------------------------------------------------- 1 | { 2 | "path": "../../examples/rates-and-dates.ipynb" 3 | } -------------------------------------------------------------------------------- /docs/source/treeflow.acceleration.bito.beagle.rst: -------------------------------------------------------------------------------- 1 | treeflow.acceleration.bito.beagle module 2 | ======================================== 3 | 4 | ..
automodule:: treeflow.acceleration.bito.beagle 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.acceleration.bito.instance.rst: -------------------------------------------------------------------------------- 1 | treeflow.acceleration.bito.instance module 2 | ========================================== 3 | 4 | .. automodule:: treeflow.acceleration.bito.instance 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.acceleration.bito.ratio_transform.rst: -------------------------------------------------------------------------------- 1 | treeflow.acceleration.bito.ratio\_transform module 2 | ================================================== 3 | 4 | .. automodule:: treeflow.acceleration.bito.ratio_transform 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.acceleration.bito.rst: -------------------------------------------------------------------------------- 1 | treeflow.acceleration.bito package 2 | ================================== 3 | 4 | .. automodule:: treeflow.acceleration.bito 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Submodules 10 | ---------- 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | 15 | treeflow.acceleration.bito.beagle 16 | treeflow.acceleration.bito.instance 17 | treeflow.acceleration.bito.ratio_transform 18 | -------------------------------------------------------------------------------- /docs/source/treeflow.acceleration.rst: -------------------------------------------------------------------------------- 1 | treeflow.acceleration package 2 | ============================= 3 | 4 | .. 
automodule:: treeflow.acceleration 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Subpackages 10 | ----------- 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | 15 | treeflow.acceleration.bito 16 | -------------------------------------------------------------------------------- /docs/source/treeflow.bijectors.fixed_topology_bijector.rst: -------------------------------------------------------------------------------- 1 | treeflow.bijectors.fixed\_topology\_bijector module 2 | =================================================== 3 | 4 | .. automodule:: treeflow.bijectors.fixed_topology_bijector 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.bijectors.highway_flow.rst: -------------------------------------------------------------------------------- 1 | treeflow.bijectors.highway\_flow module 2 | ======================================= 3 | 4 | .. automodule:: treeflow.bijectors.highway_flow 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.bijectors.highway_flow_node_bijector.rst: -------------------------------------------------------------------------------- 1 | treeflow.bijectors.highway\_flow\_node\_bijector module 2 | ======================================================= 3 | 4 | .. automodule:: treeflow.bijectors.highway_flow_node_bijector 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.bijectors.node_height_ratio_bijector.rst: -------------------------------------------------------------------------------- 1 | treeflow.bijectors.node\_height\_ratio\_bijector module 2 | ======================================================= 3 | 4 | .. 
automodule:: treeflow.bijectors.node_height_ratio_bijector 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.bijectors.preorder_node_bijector.rst: -------------------------------------------------------------------------------- 1 | treeflow.bijectors.preorder\_node\_bijector module 2 | ================================================== 3 | 4 | .. automodule:: treeflow.bijectors.preorder_node_bijector 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.bijectors.rst: -------------------------------------------------------------------------------- 1 | treeflow.bijectors package 2 | ========================== 3 | 4 | .. automodule:: treeflow.bijectors 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Submodules 10 | ---------- 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | 15 | treeflow.bijectors.fixed_topology_bijector 16 | treeflow.bijectors.highway_flow 17 | treeflow.bijectors.highway_flow_node_bijector 18 | treeflow.bijectors.node_height_ratio_bijector 19 | treeflow.bijectors.preorder_node_bijector 20 | treeflow.bijectors.tree_ratio_bijector 21 | -------------------------------------------------------------------------------- /docs/source/treeflow.bijectors.tree_ratio_bijector.rst: -------------------------------------------------------------------------------- 1 | treeflow.bijectors.tree\_ratio\_bijector module 2 | =============================================== 3 | 4 | .. 
automodule:: treeflow.bijectors.tree_ratio_bijector 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.cli.benchmark.rst: -------------------------------------------------------------------------------- 1 | treeflow.cli.benchmark module 2 | ============================= 3 | 4 | .. automodule:: treeflow.cli.benchmark 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.cli.inference_common.rst: -------------------------------------------------------------------------------- 1 | treeflow.cli.inference\_common module 2 | ===================================== 3 | 4 | .. automodule:: treeflow.cli.inference_common 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.cli.ml.rst: -------------------------------------------------------------------------------- 1 | treeflow.cli.ml module 2 | ====================== 3 | 4 | .. automodule:: treeflow.cli.ml 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.cli.rst: -------------------------------------------------------------------------------- 1 | treeflow.cli package 2 | ==================== 3 | 4 | .. automodule:: treeflow.cli 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Submodules 10 | ---------- 11 | 12 | .. 
toctree:: 13 | :maxdepth: 1 14 | 15 | treeflow.cli.benchmark 16 | treeflow.cli.inference_common 17 | treeflow.cli.ml 18 | treeflow.cli.vi 19 | -------------------------------------------------------------------------------- /docs/source/treeflow.cli.vi.rst: -------------------------------------------------------------------------------- 1 | treeflow.cli.vi module 2 | ====================== 3 | 4 | .. automodule:: treeflow.cli.vi 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.debug.minimize_eager.rst: -------------------------------------------------------------------------------- 1 | treeflow.debug.minimize\_eager module 2 | ===================================== 3 | 4 | .. automodule:: treeflow.debug.minimize_eager 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.debug.nonfinite_convergence_criterion.rst: -------------------------------------------------------------------------------- 1 | treeflow.debug.nonfinite\_convergence\_criterion module 2 | ======================================================= 3 | 4 | .. automodule:: treeflow.debug.nonfinite_convergence_criterion 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.debug.rst: -------------------------------------------------------------------------------- 1 | treeflow.debug package 2 | ====================== 3 | 4 | .. automodule:: treeflow.debug 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Submodules 10 | ---------- 11 | 12 | .. 
toctree:: 13 | :maxdepth: 1 14 | 15 | treeflow.debug.minimize_eager 16 | treeflow.debug.nonfinite_convergence_criterion 17 | -------------------------------------------------------------------------------- /docs/source/treeflow.distributions.discrete.rst: -------------------------------------------------------------------------------- 1 | treeflow.distributions.discrete module 2 | ====================================== 3 | 4 | .. automodule:: treeflow.distributions.discrete 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.distributions.discrete_parameter_mixture.rst: -------------------------------------------------------------------------------- 1 | treeflow.distributions.discrete\_parameter\_mixture module 2 | ========================================================== 3 | 4 | .. automodule:: treeflow.distributions.discrete_parameter_mixture 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.distributions.discretized.rst: -------------------------------------------------------------------------------- 1 | treeflow.distributions.discretized module 2 | ========================================= 3 | 4 | .. automodule:: treeflow.distributions.discretized 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.distributions.leaf_ctmc.rst: -------------------------------------------------------------------------------- 1 | treeflow.distributions.leaf\_ctmc module 2 | ======================================== 3 | 4 | .. 
automodule:: treeflow.distributions.leaf_ctmc 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.distributions.markov_chain.linear_gaussian.rst: -------------------------------------------------------------------------------- 1 | treeflow.distributions.markov\_chain.linear\_gaussian module 2 | ============================================================ 3 | 4 | .. automodule:: treeflow.distributions.markov_chain.linear_gaussian 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.distributions.markov_chain.postorder.rst: -------------------------------------------------------------------------------- 1 | treeflow.distributions.markov\_chain.postorder module 2 | ===================================================== 3 | 4 | .. automodule:: treeflow.distributions.markov_chain.postorder 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.distributions.markov_chain.rst: -------------------------------------------------------------------------------- 1 | treeflow.distributions.markov\_chain package 2 | ============================================ 3 | 4 | .. automodule:: treeflow.distributions.markov_chain 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Submodules 10 | ---------- 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | 15 | treeflow.distributions.markov_chain.linear_gaussian 16 | treeflow.distributions.markov_chain.postorder 17 | -------------------------------------------------------------------------------- /docs/source/treeflow.distributions.rst: -------------------------------------------------------------------------------- 1 | treeflow.distributions package 2 | ============================== 3 | 4 | .. 
automodule:: treeflow.distributions 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Subpackages 10 | ----------- 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | 15 | treeflow.distributions.markov_chain 16 | treeflow.distributions.tree 17 | 18 | Submodules 19 | ---------- 20 | 21 | .. toctree:: 22 | :maxdepth: 1 23 | 24 | treeflow.distributions.discrete 25 | treeflow.distributions.discrete_parameter_mixture 26 | treeflow.distributions.discretized 27 | treeflow.distributions.leaf_ctmc 28 | treeflow.distributions.sample_weighted 29 | -------------------------------------------------------------------------------- /docs/source/treeflow.distributions.sample_weighted.rst: -------------------------------------------------------------------------------- 1 | treeflow.distributions.sample\_weighted module 2 | ============================================== 3 | 4 | .. automodule:: treeflow.distributions.sample_weighted 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.distributions.tree.base_tree_distribution.rst: -------------------------------------------------------------------------------- 1 | treeflow.distributions.tree.base\_tree\_distribution module 2 | =========================================================== 3 | 4 | .. automodule:: treeflow.distributions.tree.base_tree_distribution 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.distributions.tree.birthdeath.birth_death_contemporary_sampling.rst: -------------------------------------------------------------------------------- 1 | treeflow.distributions.tree.birthdeath.birth\_death\_contemporary\_sampling module 2 | ================================================================================== 3 | 4 | .. 
automodule:: treeflow.distributions.tree.birthdeath.birth_death_contemporary_sampling 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.distributions.tree.birthdeath.rst: -------------------------------------------------------------------------------- 1 | treeflow.distributions.tree.birthdeath package 2 | ============================================== 3 | 4 | .. automodule:: treeflow.distributions.tree.birthdeath 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Submodules 10 | ---------- 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | 15 | treeflow.distributions.tree.birthdeath.birth_death_contemporary_sampling 16 | treeflow.distributions.tree.birthdeath.yule 17 | -------------------------------------------------------------------------------- /docs/source/treeflow.distributions.tree.birthdeath.yule.rst: -------------------------------------------------------------------------------- 1 | treeflow.distributions.tree.birthdeath.yule module 2 | ================================================== 3 | 4 | .. automodule:: treeflow.distributions.tree.birthdeath.yule 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.distributions.tree.coalescent.constant_coalescent.rst: -------------------------------------------------------------------------------- 1 | treeflow.distributions.tree.coalescent.constant\_coalescent module 2 | ================================================================== 3 | 4 | .. 
automodule:: treeflow.distributions.tree.coalescent.constant_coalescent 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.distributions.tree.coalescent.rst: -------------------------------------------------------------------------------- 1 | treeflow.distributions.tree.coalescent package 2 | ============================================== 3 | 4 | .. automodule:: treeflow.distributions.tree.coalescent 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Submodules 10 | ---------- 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | 15 | treeflow.distributions.tree.coalescent.constant_coalescent 16 | -------------------------------------------------------------------------------- /docs/source/treeflow.distributions.tree.rooted_tree_distribution.rst: -------------------------------------------------------------------------------- 1 | treeflow.distributions.tree.rooted\_tree\_distribution module 2 | ============================================================= 3 | 4 | .. automodule:: treeflow.distributions.tree.rooted_tree_distribution 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.distributions.tree.rst: -------------------------------------------------------------------------------- 1 | treeflow.distributions.tree package 2 | =================================== 3 | 4 | .. automodule:: treeflow.distributions.tree 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Subpackages 10 | ----------- 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | 15 | treeflow.distributions.tree.birthdeath 16 | treeflow.distributions.tree.coalescent 17 | 18 | Submodules 19 | ---------- 20 | 21 | .. 
toctree:: 22 | :maxdepth: 1 23 | 24 | treeflow.distributions.tree.base_tree_distribution 25 | treeflow.distributions.tree.rooted_tree_distribution 26 | -------------------------------------------------------------------------------- /docs/source/treeflow.evolution.calibration.calibration.rst: -------------------------------------------------------------------------------- 1 | treeflow.evolution.calibration.calibration module 2 | ================================================= 3 | 4 | .. automodule:: treeflow.evolution.calibration.calibration 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.evolution.calibration.mrca.rst: -------------------------------------------------------------------------------- 1 | treeflow.evolution.calibration.mrca module 2 | ========================================== 3 | 4 | .. automodule:: treeflow.evolution.calibration.mrca 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.evolution.calibration.rst: -------------------------------------------------------------------------------- 1 | treeflow.evolution.calibration package 2 | ====================================== 3 | 4 | .. automodule:: treeflow.evolution.calibration 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Submodules 10 | ---------- 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | 15 | treeflow.evolution.calibration.calibration 16 | treeflow.evolution.calibration.mrca 17 | -------------------------------------------------------------------------------- /docs/source/treeflow.evolution.rst: -------------------------------------------------------------------------------- 1 | treeflow.evolution package 2 | ========================== 3 | 4 | .. 
automodule:: treeflow.evolution 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Subpackages 10 | ----------- 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | 15 | treeflow.evolution.calibration 16 | treeflow.evolution.substitution 17 | 18 | Submodules 19 | ---------- 20 | 21 | .. toctree:: 22 | :maxdepth: 1 23 | 24 | treeflow.evolution.seqio 25 | -------------------------------------------------------------------------------- /docs/source/treeflow.evolution.seqio.rst: -------------------------------------------------------------------------------- 1 | treeflow.evolution.seqio module 2 | =============================== 3 | 4 | .. automodule:: treeflow.evolution.seqio 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.evolution.substitution.base_substitution_model.rst: -------------------------------------------------------------------------------- 1 | treeflow.evolution.substitution.base\_substitution\_model module 2 | ================================================================ 3 | 4 | .. automodule:: treeflow.evolution.substitution.base_substitution_model 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.evolution.substitution.eigendecomposition.rst: -------------------------------------------------------------------------------- 1 | treeflow.evolution.substitution.eigendecomposition module 2 | ========================================================= 3 | 4 | .. 
automodule:: treeflow.evolution.substitution.eigendecomposition 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.evolution.substitution.nucleotide.alphabet.rst: -------------------------------------------------------------------------------- 1 | treeflow.evolution.substitution.nucleotide.alphabet module 2 | ========================================================== 3 | 4 | .. automodule:: treeflow.evolution.substitution.nucleotide.alphabet 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.evolution.substitution.nucleotide.gtr.rst: -------------------------------------------------------------------------------- 1 | treeflow.evolution.substitution.nucleotide.gtr module 2 | ===================================================== 3 | 4 | .. automodule:: treeflow.evolution.substitution.nucleotide.gtr 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.evolution.substitution.nucleotide.hky.rst: -------------------------------------------------------------------------------- 1 | treeflow.evolution.substitution.nucleotide.hky module 2 | ===================================================== 3 | 4 | .. automodule:: treeflow.evolution.substitution.nucleotide.hky 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.evolution.substitution.nucleotide.jc.rst: -------------------------------------------------------------------------------- 1 | treeflow.evolution.substitution.nucleotide.jc module 2 | ==================================================== 3 | 4 | .. 
automodule:: treeflow.evolution.substitution.nucleotide.jc 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.evolution.substitution.nucleotide.rst: -------------------------------------------------------------------------------- 1 | treeflow.evolution.substitution.nucleotide package 2 | ================================================== 3 | 4 | .. automodule:: treeflow.evolution.substitution.nucleotide 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Submodules 10 | ---------- 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | 15 | treeflow.evolution.substitution.nucleotide.alphabet 16 | treeflow.evolution.substitution.nucleotide.gtr 17 | treeflow.evolution.substitution.nucleotide.hky 18 | treeflow.evolution.substitution.nucleotide.jc 19 | -------------------------------------------------------------------------------- /docs/source/treeflow.evolution.substitution.probabilities.rst: -------------------------------------------------------------------------------- 1 | treeflow.evolution.substitution.probabilities module 2 | ==================================================== 3 | 4 | .. automodule:: treeflow.evolution.substitution.probabilities 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.evolution.substitution.rst: -------------------------------------------------------------------------------- 1 | treeflow.evolution.substitution package 2 | ======================================= 3 | 4 | .. automodule:: treeflow.evolution.substitution 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Subpackages 10 | ----------- 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | 15 | treeflow.evolution.substitution.nucleotide 16 | 17 | Submodules 18 | ---------- 19 | 20 | .. 
toctree:: 21 | :maxdepth: 1 22 | 23 | treeflow.evolution.substitution.base_substitution_model 24 | treeflow.evolution.substitution.eigendecomposition 25 | treeflow.evolution.substitution.probabilities 26 | treeflow.evolution.substitution.util 27 | -------------------------------------------------------------------------------- /docs/source/treeflow.evolution.substitution.util.rst: -------------------------------------------------------------------------------- 1 | treeflow.evolution.substitution.util module 2 | =========================================== 3 | 4 | .. automodule:: treeflow.evolution.substitution.util 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.model.approximation.cascading_flows.rst: -------------------------------------------------------------------------------- 1 | treeflow.model.approximation.cascading\_flows module 2 | ==================================================== 3 | 4 | .. automodule:: treeflow.model.approximation.cascading_flows 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.model.approximation.iaf.rst: -------------------------------------------------------------------------------- 1 | treeflow.model.approximation.iaf module 2 | ======================================= 3 | 4 | .. automodule:: treeflow.model.approximation.iaf 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.model.approximation.mean_field.rst: -------------------------------------------------------------------------------- 1 | treeflow.model.approximation.mean\_field module 2 | =============================================== 3 | 4 | .. 
automodule:: treeflow.model.approximation.mean_field 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.model.approximation.rst: -------------------------------------------------------------------------------- 1 | treeflow.model.approximation package 2 | ==================================== 3 | 4 | .. automodule:: treeflow.model.approximation 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Submodules 10 | ---------- 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | 15 | treeflow.model.approximation.cascading_flows 16 | treeflow.model.approximation.iaf 17 | treeflow.model.approximation.mean_field 18 | -------------------------------------------------------------------------------- /docs/source/treeflow.model.event_shape_bijector.rst: -------------------------------------------------------------------------------- 1 | treeflow.model.event\_shape\_bijector module 2 | ============================================ 3 | 4 | .. automodule:: treeflow.model.event_shape_bijector 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.model.io.rst: -------------------------------------------------------------------------------- 1 | treeflow.model.io module 2 | ======================== 3 | 4 | .. automodule:: treeflow.model.io 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.model.ml.rst: -------------------------------------------------------------------------------- 1 | treeflow.model.ml module 2 | ======================== 3 | 4 | .. 
automodule:: treeflow.model.ml 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.model.phylo_model.rst: -------------------------------------------------------------------------------- 1 | treeflow.model.phylo\_model module 2 | ================================== 3 | 4 | .. automodule:: treeflow.model.phylo_model 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.model.rst: -------------------------------------------------------------------------------- 1 | treeflow.model package 2 | ====================== 3 | 4 | .. automodule:: treeflow.model 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Subpackages 10 | ----------- 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | 15 | treeflow.model.approximation 16 | 17 | Submodules 18 | ---------- 19 | 20 | .. toctree:: 21 | :maxdepth: 1 22 | 23 | treeflow.model.event_shape_bijector 24 | treeflow.model.io 25 | treeflow.model.ml 26 | treeflow.model.phylo_model 27 | treeflow.model.structured_approximation 28 | -------------------------------------------------------------------------------- /docs/source/treeflow.model.structured_approximation.rst: -------------------------------------------------------------------------------- 1 | treeflow.model.structured\_approximation module 2 | =============================================== 3 | 4 | .. automodule:: treeflow.model.structured_approximation 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.rst: -------------------------------------------------------------------------------- 1 | treeflow package 2 | ================ 3 | 4 | .. 
automodule:: treeflow 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Subpackages 10 | ----------- 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | 15 | treeflow.acceleration 16 | treeflow.bijectors 17 | treeflow.cli 18 | treeflow.debug 19 | treeflow.distributions 20 | treeflow.evolution 21 | treeflow.model 22 | treeflow.tf_util 23 | treeflow.traversal 24 | treeflow.tree 25 | treeflow.vi 26 | -------------------------------------------------------------------------------- /docs/source/treeflow.tf_util.attrs.rst: -------------------------------------------------------------------------------- 1 | treeflow.tf\_util.attrs module 2 | ============================== 3 | 4 | .. automodule:: treeflow.tf_util.attrs 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.tf_util.dtype_util.rst: -------------------------------------------------------------------------------- 1 | treeflow.tf\_util.dtype\_util module 2 | ==================================== 3 | 4 | .. automodule:: treeflow.tf_util.dtype_util 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.tf_util.linear_operator_upper_triangular.rst: -------------------------------------------------------------------------------- 1 | treeflow.tf\_util.linear\_operator\_upper\_triangular module 2 | ============================================================ 3 | 4 | .. automodule:: treeflow.tf_util.linear_operator_upper_triangular 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.tf_util.rst: -------------------------------------------------------------------------------- 1 | treeflow.tf\_util package 2 | ========================= 3 | 4 | .. 
automodule:: treeflow.tf_util 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Submodules 10 | ---------- 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | 15 | treeflow.tf_util.attrs 16 | treeflow.tf_util.dtype_util 17 | treeflow.tf_util.linear_operator_upper_triangular 18 | treeflow.tf_util.vectorize 19 | -------------------------------------------------------------------------------- /docs/source/treeflow.tf_util.vectorize.rst: -------------------------------------------------------------------------------- 1 | treeflow.tf\_util.vectorize module 2 | ================================== 3 | 4 | .. automodule:: treeflow.tf_util.vectorize 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.traversal.anchor_heights.rst: -------------------------------------------------------------------------------- 1 | treeflow.traversal.anchor\_heights module 2 | ========================================= 3 | 4 | .. automodule:: treeflow.traversal.anchor_heights 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.traversal.phylo_likelihood.rst: -------------------------------------------------------------------------------- 1 | treeflow.traversal.phylo\_likelihood module 2 | =========================================== 3 | 4 | .. automodule:: treeflow.traversal.phylo_likelihood 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.traversal.postorder.rst: -------------------------------------------------------------------------------- 1 | treeflow.traversal.postorder module 2 | =================================== 3 | 4 | .. 
automodule:: treeflow.traversal.postorder 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.traversal.preorder.rst: -------------------------------------------------------------------------------- 1 | treeflow.traversal.preorder module 2 | ================================== 3 | 4 | .. automodule:: treeflow.traversal.preorder 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.traversal.ratio_transform.rst: -------------------------------------------------------------------------------- 1 | treeflow.traversal.ratio\_transform module 2 | ========================================== 3 | 4 | .. automodule:: treeflow.traversal.ratio_transform 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.traversal.rst: -------------------------------------------------------------------------------- 1 | treeflow.traversal package 2 | ========================== 3 | 4 | .. automodule:: treeflow.traversal 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Submodules 10 | ---------- 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | 15 | treeflow.traversal.anchor_heights 16 | treeflow.traversal.phylo_likelihood 17 | treeflow.traversal.postorder 18 | treeflow.traversal.preorder 19 | treeflow.traversal.ratio_transform 20 | -------------------------------------------------------------------------------- /docs/source/treeflow.tree.base_tree.rst: -------------------------------------------------------------------------------- 1 | treeflow.tree.base\_tree module 2 | =============================== 3 | 4 | .. 
automodule:: treeflow.tree.base_tree 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.tree.io.rst: -------------------------------------------------------------------------------- 1 | treeflow.tree.io module 2 | ======================= 3 | 4 | .. automodule:: treeflow.tree.io 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.tree.rooted.base_rooted_tree.rst: -------------------------------------------------------------------------------- 1 | treeflow.tree.rooted.base\_rooted\_tree module 2 | ============================================== 3 | 4 | .. automodule:: treeflow.tree.rooted.base_rooted_tree 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.tree.rooted.numpy_rooted_tree.rst: -------------------------------------------------------------------------------- 1 | treeflow.tree.rooted.numpy\_rooted\_tree module 2 | =============================================== 3 | 4 | .. automodule:: treeflow.tree.rooted.numpy_rooted_tree 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.tree.rooted.rst: -------------------------------------------------------------------------------- 1 | treeflow.tree.rooted package 2 | ============================ 3 | 4 | .. automodule:: treeflow.tree.rooted 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Submodules 10 | ---------- 11 | 12 | .. 
toctree:: 13 | :maxdepth: 1 14 | 15 | treeflow.tree.rooted.base_rooted_tree 16 | treeflow.tree.rooted.numpy_rooted_tree 17 | treeflow.tree.rooted.tensorflow_rooted_tree 18 | -------------------------------------------------------------------------------- /docs/source/treeflow.tree.rooted.tensorflow_rooted_tree.rst: -------------------------------------------------------------------------------- 1 | treeflow.tree.rooted.tensorflow\_rooted\_tree module 2 | ==================================================== 3 | 4 | .. automodule:: treeflow.tree.rooted.tensorflow_rooted_tree 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.tree.rst: -------------------------------------------------------------------------------- 1 | treeflow.tree package 2 | ===================== 3 | 4 | .. automodule:: treeflow.tree 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Subpackages 10 | ----------- 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | 15 | treeflow.tree.rooted 16 | treeflow.tree.topology 17 | treeflow.tree.unrooted 18 | 19 | Submodules 20 | ---------- 21 | 22 | .. toctree:: 23 | :maxdepth: 1 24 | 25 | treeflow.tree.base_tree 26 | treeflow.tree.io 27 | treeflow.tree.taxon_set 28 | -------------------------------------------------------------------------------- /docs/source/treeflow.tree.taxon_set.rst: -------------------------------------------------------------------------------- 1 | treeflow.tree.taxon\_set module 2 | =============================== 3 | 4 | .. 
automodule:: treeflow.tree.taxon_set 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.tree.topology.base_tree_topology.rst: -------------------------------------------------------------------------------- 1 | treeflow.tree.topology.base\_tree\_topology module 2 | ================================================== 3 | 4 | .. automodule:: treeflow.tree.topology.base_tree_topology 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.tree.topology.numpy_topology_operations.rst: -------------------------------------------------------------------------------- 1 | treeflow.tree.topology.numpy\_topology\_operations module 2 | ========================================================= 3 | 4 | .. automodule:: treeflow.tree.topology.numpy_topology_operations 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.tree.topology.numpy_tree_topology.rst: -------------------------------------------------------------------------------- 1 | treeflow.tree.topology.numpy\_tree\_topology module 2 | =================================================== 3 | 4 | .. automodule:: treeflow.tree.topology.numpy_tree_topology 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.tree.topology.rst: -------------------------------------------------------------------------------- 1 | treeflow.tree.topology package 2 | ============================== 3 | 4 | .. automodule:: treeflow.tree.topology 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Submodules 10 | ---------- 11 | 12 | .. 
toctree:: 13 | :maxdepth: 1 14 | 15 | treeflow.tree.topology.base_tree_topology 16 | treeflow.tree.topology.numpy_topology_operations 17 | treeflow.tree.topology.numpy_tree_topology 18 | treeflow.tree.topology.tensorflow_tree_topology 19 | -------------------------------------------------------------------------------- /docs/source/treeflow.tree.topology.tensorflow_tree_topology.rst: -------------------------------------------------------------------------------- 1 | treeflow.tree.topology.tensorflow\_tree\_topology module 2 | ======================================================== 3 | 4 | .. automodule:: treeflow.tree.topology.tensorflow_tree_topology 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.tree.unrooted.base_unrooted_tree.rst: -------------------------------------------------------------------------------- 1 | treeflow.tree.unrooted.base\_unrooted\_tree module 2 | ================================================== 3 | 4 | .. automodule:: treeflow.tree.unrooted.base_unrooted_tree 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.tree.unrooted.numpy_unrooted_tree.rst: -------------------------------------------------------------------------------- 1 | treeflow.tree.unrooted.numpy\_unrooted\_tree module 2 | =================================================== 3 | 4 | .. automodule:: treeflow.tree.unrooted.numpy_unrooted_tree 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.tree.unrooted.rst: -------------------------------------------------------------------------------- 1 | treeflow.tree.unrooted package 2 | ============================== 3 | 4 | .. 
automodule:: treeflow.tree.unrooted 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Submodules 10 | ---------- 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | 15 | treeflow.tree.unrooted.base_unrooted_tree 16 | treeflow.tree.unrooted.numpy_unrooted_tree 17 | treeflow.tree.unrooted.tensorflow_unrooted_tree 18 | -------------------------------------------------------------------------------- /docs/source/treeflow.tree.unrooted.tensorflow_unrooted_tree.rst: -------------------------------------------------------------------------------- 1 | treeflow.tree.unrooted.tensorflow\_unrooted\_tree module 2 | ======================================================== 3 | 4 | .. automodule:: treeflow.tree.unrooted.tensorflow_unrooted_tree 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.vi.convergence_criteria.nonfinite.rst: -------------------------------------------------------------------------------- 1 | treeflow.vi.convergence\_criteria.nonfinite module 2 | ================================================== 3 | 4 | .. automodule:: treeflow.vi.convergence_criteria.nonfinite 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.vi.convergence_criteria.rst: -------------------------------------------------------------------------------- 1 | treeflow.vi.convergence\_criteria package 2 | ========================================= 3 | 4 | .. automodule:: treeflow.vi.convergence_criteria 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Submodules 10 | ---------- 11 | 12 | .. 
toctree:: 13 | :maxdepth: 1 14 | 15 | treeflow.vi.convergence_criteria.nonfinite 16 | -------------------------------------------------------------------------------- /docs/source/treeflow.vi.fixed_topology_advi.rst: -------------------------------------------------------------------------------- 1 | treeflow.vi.fixed\_topology\_advi module 2 | ======================================== 3 | 4 | .. automodule:: treeflow.vi.fixed_topology_advi 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.vi.marginal_likelihood.rst: -------------------------------------------------------------------------------- 1 | treeflow.vi.marginal\_likelihood module 2 | ======================================= 3 | 4 | .. automodule:: treeflow.vi.marginal_likelihood 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.vi.optimizers.robust_optimizer.rst: -------------------------------------------------------------------------------- 1 | treeflow.vi.optimizers.robust\_optimizer module 2 | =============================================== 3 | 4 | .. automodule:: treeflow.vi.optimizers.robust_optimizer 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.vi.optimizers.rst: -------------------------------------------------------------------------------- 1 | treeflow.vi.optimizers package 2 | ============================== 3 | 4 | .. automodule:: treeflow.vi.optimizers 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Submodules 10 | ---------- 11 | 12 | .. 
toctree:: 13 | :maxdepth: 1 14 | 15 | treeflow.vi.optimizers.robust_optimizer 16 | -------------------------------------------------------------------------------- /docs/source/treeflow.vi.progress_bar.rst: -------------------------------------------------------------------------------- 1 | treeflow.vi.progress\_bar module 2 | ================================ 3 | 4 | .. automodule:: treeflow.vi.progress_bar 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/treeflow.vi.rst: -------------------------------------------------------------------------------- 1 | treeflow.vi package 2 | =================== 3 | 4 | .. automodule:: treeflow.vi 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | Subpackages 10 | ----------- 11 | 12 | .. toctree:: 13 | :maxdepth: 1 14 | 15 | treeflow.vi.convergence_criteria 16 | treeflow.vi.optimizers 17 | 18 | Submodules 19 | ---------- 20 | 21 | .. toctree:: 22 | :maxdepth: 1 23 | 24 | treeflow.vi.fixed_topology_advi 25 | treeflow.vi.marginal_likelihood 26 | treeflow.vi.progress_bar 27 | treeflow.vi.util 28 | -------------------------------------------------------------------------------- /docs/source/treeflow.vi.util.rst: -------------------------------------------------------------------------------- 1 | treeflow.vi.util module 2 | ======================= 3 | 4 | .. automodule:: treeflow.vi.util 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/tutorials.rst: -------------------------------------------------------------------------------- 1 | Tutorials 2 | ========================== 3 | 4 | .. 
toctree:: 5 | :maxdepth: 1 6 | 7 | rates-and-dates -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # TreeFlow examples 2 | 3 | ## H3N2 4 | 5 | [`h3n2-vi.sh`](h3n2-vi.sh) uses TreeFlow's variational inference command line interface to estimate dates and model parameters on an alignment of 980 influenza genomes, taken from: 6 | 7 | > Vaughan, Timothy G., et al. "Efficient Bayesian inference under the structured coalescent." *Bioinformatics* 30.16 (2014): 2272-2279. 8 | 9 | It uses the model specified in [`h3n2-model.yaml`](h3n2-model.yaml). 10 | 11 | ## Rates and dates 12 | [`rates-and-dates.ipynb`](rates-and-dates.ipynb) is a Jupyter notebook that demonstrates TreeFlow's variational inference and model comparison API. We also provide a YAML version of the model definition in [`rates-and-dates-model.yaml`](rates-and-dates-model.yaml). 13 | 14 | The data and model are based on [the BEAST documentation](https://beast.community/rates_and_dates). The original sequences are taken from: 15 | 16 | > Bryant, Juliet E., Edward C. Holmes, and Alan D. T. Barrett. "Out of Africa: a molecular perspective on the introduction of yellow fever virus into the Americas." *PLoS Pathogens* 3.5 (2007): e75. 17 | 18 | ## Carnivores 19 | 20 | [`carnivores.ipynb`](carnivores.ipynb) is a Jupyter notebook that shows how TreeFlow's probabilistic modelling API can be used for rapid model development. It investigates variation in the transition-transversion ratio over lineages. 21 | 22 | The dataset is an alignment of mitochondrial DNA sequences from carnivores, [accessed from the BEAST examples](https://github.com/beast-dev/beast-mcmc/blob/v1.10.4/examples/Benchmarks/benchmark2.xml), taken from: 23 | 24 | > Suchard, Marc A., and Andrew Rambaut. "Many-core algorithms for statistical phylogenetics." *Bioinformatics* 25.11 (2009): 1370-1376.
25 | 26 | -------------------------------------------------------------------------------- /examples/demo-data/carnivores.newick: -------------------------------------------------------------------------------- 1 | ((((((Puma_concolor_:0.05844,Acinonyx_jubatus_:0.05844):0.01360,Lynx_canadensis_:0.07203):0.01447,Felis_silvestris_:0.08650):0.03656,(((Uncia_uncia_:0.04698,Panthera_pardus_:0.04698):0.01748,Panthera_tigris_:0.06446):0.02085,Neofelis_nebulosa_:0.08531):0.03775):0.17818,Herpestes_auropunctatus_:0.30124):0.16816,((((((((((Leptonychotes_weddellii_:0.03246,Hydrurga_leptonyx_:0.03246):0.02098,Ommatophoca_rossii_:0.05344):0.00319,Lobodon_carcinophaga_:0.05663):0.02926,(Mirounga_leonina_:0.01815,Mirounga_angustirostris_:0.01815):0.06774):0.01934,(Monachus_monachus_:0.08867,Monachus_schauinslandi_:0.08867):0.01656):0.04898,((((Phoca_fasciata_:0.03805,Phoca_groenlandica_:0.03805):0.01418,((((Phoca_caspica_:0.01401,Halichoerus_grypus_:0.01401):0.00023,Phoca_sibirica_:0.01424):0.00235,(Phoca_largha_:0.00837,Phoca_vitulina_:0.00837):0.00821):0.00129,Phoca_hispida_:0.01788):0.03434):0.01376,Cystophora_cristata_:0.06598):0.04440,Erignathus_barbatus_:0.11038):0.04382):0.10038,((((((((Arctocephalus_forsteri_:0.01159,Arctocephalus_australis_:0.01159):0.01965,Arctocephalus_townsendi_:0.03124):0.01040,(Neophoca_cinerea_:0.04047,Phocarctos_hookeri_:0.04047):0.00117):0.00259,Arctocephalus_pusillus_:0.04423):0.00171,Otaria_byronia_:0.04594):0.00769,(Zalophus_californianus_:0.03912,Eumetopias_jubatus_:0.03912):0.01451):0.03942,Callorhinus_ursinus_:0.09305):0.11460,Odobenus_rosmarus_:0.20765):0.04693):0.13094,(((((((Enhydra_lutris_:0.12728,Lontra_canadensis_:0.12728):0.03222,Mustela_vison_:0.15950):0.00466,((Martes_melampus_:0.02638,Martes_americana_:0.02638):0.07874,Gulo_gulo_:0.10512):0.05905):0.01360,Meles_meles_:0.17776):0.02164,Taxidea_taxus_:0.19940):0.12468,Procyon_lotor_:0.32409):0.04397,((Mephitis_mephitis_:0.12282,Spilogale_putorius_:0.12282):0.23149,Ailurus_fulgens_:0
.35431):0.01374):0.01747):0.00000,((((((Ursus_thibetanus_:0.04164,Ursus_americanus_:0.04164):0.00647,Helarctos_malayanus_:0.04811):0.00540,(Ursus_arctos_:0.00992,Ursus_maritimus_:0.00992):0.04359):0.00892,Melursus_ursinus_:0.06243):0.08279,Tremarctos_ornatus_:0.14522):0.05334,Ailuropoda_melanoleuca_:0.19856):0.18696):0.07159,((Canis_latrans_:0.02980,Canis_lupus_:0.02980):0.13993,(Alopex_lagopus_:0.05467,Vulpes_vulpes_:0.05467):0.11506):0.28737):0.01230):0.00000; 2 | -------------------------------------------------------------------------------- /examples/h3n2-model.yaml: -------------------------------------------------------------------------------- 1 | clock: 2 | strict: 3 | clock_rate: 4 | lognormal: 5 | loc: -2.0 6 | scale: 2.0 7 | site: 8 | discrete_gamma: 9 | category_count: 4 10 | site_gamma_shape: 11 | lognormal: 12 | loc: 0.0 13 | scale: 1.0 14 | substitution: 15 | gtr_rel: 16 | frequencies: 17 | dirichlet: 18 | concentration: 19 | - 2.0 20 | - 2.0 21 | - 2.0 22 | - 2.0 23 | rate_ac: 24 | gamma: 25 | concentration: 0.05 26 | rate: 0.05 27 | rate_ag: 28 | gamma: 29 | concentration: 0.05 30 | rate: 0.05 31 | rate_at: 32 | gamma: 33 | concentration: 0.05 34 | rate: 0.05 35 | rate_cg: 36 | gamma: 37 | concentration: 0.05 38 | rate: 0.05 39 | rate_gt: 40 | gamma: 41 | concentration: 0.05 42 | rate: 0.05 43 | tree: 44 | coalescent: 45 | pop_size: 46 | lognormal: 47 | loc: 1.0 48 | scale: 1.5 49 | -------------------------------------------------------------------------------- /examples/h3n2-vi.sh: -------------------------------------------------------------------------------- 1 | treeflow_vi -s 1 \ 2 | -i demo-data/h3n2.fasta \ 3 | -m h3n2-model.yaml \ 4 | -t demo-data/h3n2.nwk \ 5 | -n 30000 \ 6 | --learning-rate 0.001 \ 7 | --init-values "clock_rate=0.003" \ 8 | --trace-output demo-out/h3n2-trace.pickle \ 9 | --samples-output demo-out/h3n2-samples.csv \ 10 | --tree-samples-output demo-out/h3n2-trees.nexus \ 11 | --n-output-samples 1000 
-------------------------------------------------------------------------------- /examples/rates-and-dates-model.yaml: -------------------------------------------------------------------------------- 1 | clock: 2 | strict: 3 | clock_rate: 4 | lognormal: 5 | loc: -2.0 6 | scale: 2.0 7 | site: 8 | discrete_gamma: 9 | category_count: 4 10 | site_gamma_shape: 11 | lognormal: 12 | loc: 0.0 13 | scale: 1.0 14 | substitution: 15 | hky: 16 | kappa: 17 | lognormal: 18 | loc: 1.0 19 | scale: 1.25 20 | frequencies: 21 | dirichlet: 22 | concentration: 23 | - 2.0 24 | - 2.0 25 | - 2.0 26 | - 2.0 27 | tree: 28 | coalescent: 29 | pop_size: 30 | lognormal: 31 | loc: 1.0 32 | scale: 1.5 -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools >= 56.0"] -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | norecursedirs = test/fixtures/ test/helpers/ scratch/* -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | attrs==23.2.0 2 | click==8.1.3 3 | DendroPy==4.5.2 4 | ete3==3.1.2 5 | jupyterlab==4.1.3 6 | numpy==1.22.3 7 | pandas==1.5.0 8 | PyYAML==6.0 9 | silence-tensorflow==1.2.1 10 | tensorflow==2.11.0 11 | tensorflow-estimator==2.11.0 12 | tensorflow-io-gcs-filesystem==0.25.0 13 | tensorflow-probability==0.19.0 14 | tqdm==4.64.0 15 | typing_extensions==4.2.0 16 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = treeflow 3 | version = 0.0.1 4 | url = https://github.com/christiaanjs/treeflow 5 | author = 
Christiaan Swanepoel 6 | author_email = christiaan.j.s@gmail.com 7 | keywords = phylogenetics, tensorflow 8 | description = Phylogenetics in Tensorflow 9 | long_description_content_type = text/markdown 10 | license = GPL3 11 | license_file = LICENSE 12 | classifiers = 13 | Intended Audience :: Science/Research 14 | License :: OSI Approved :: GNU General Public License v3 (GPLv3) 15 | Operating System :: OS Independent 16 | Programming Language :: Python :: 3.7 17 | Programming Language :: Python :: 3.8 18 | Programming Language :: Python :: 3.9 19 | Topic :: Scientific/Engineering :: Bio-Informatics 20 | 21 | 22 | [options] 23 | python_requires = >=3.7 24 | packages = find: 25 | package_dir = 26 | =. 27 | install_requires = 28 | tensorflow>=2.11.0 29 | tensorflow_probability>=0.19.0 30 | numpy>=1.19 31 | ete3>=3.1.2 32 | attrs>=21.1.0 33 | PyYAML>=6.0 34 | dendropy>=4.5.2 35 | click>=8.1.2 36 | tqdm>=4.64.0 37 | silence_tensorflow>=1.2.1 38 | test_requires = 39 | pandas>=1.3.5 40 | pytest>=7.1.2 41 | 42 | [options.entry_points] 43 | console_scripts = 44 | treeflow_benchmark = treeflow.cli.benchmark:treeflow_benchmark 45 | treeflow_vi = treeflow.cli.vi:treeflow_vi 46 | treeflow_ml = treeflow.cli.ml:treeflow_ml 47 | 48 | 49 | [options.extras_require] 50 | benchmark = 51 | memory_profiler 52 | test = 53 | pandas>=1.3.5 54 | pytest>=7.1.2 55 | accelerated = 56 | bito -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | if __name__ == "__main__": 4 | setup() 5 | -------------------------------------------------------------------------------- /test/acceleration/bito/test_beagle.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from numpy.testing import assert_allclose 3 | import tensorflow as tf 4 | from treeflow.evolution.seqio import Alignment 5 | from 
treeflow.evolution.substitution.nucleotide.hky import HKY 6 | from treeflow.evolution.substitution.probabilities import ( 7 | get_transition_probabilities_tree, 8 | ) 9 | from treeflow.tree.io import parse_newick 10 | from treeflow.tree.rooted.tensorflow_rooted_tree import convert_tree_to_tensor 11 | from tensorflow_probability.python.distributions import Sample 12 | from treeflow.distributions.leaf_ctmc import LeafCTMC 13 | 14 | 15 | def test_log_prob_conditioned_hky(hky_params, newick_fasta_file_dated): 16 | from treeflow.acceleration.bito.beagle import ( 17 | phylogenetic_likelihood as beagle_likelihood, 18 | ) 19 | 20 | newick_file, fasta_file, dated = newick_fasta_file_dated 21 | subst_model = HKY() 22 | tensor_tree = convert_tree_to_tensor(parse_newick(newick_file)).get_unrooted_tree() 23 | alignment = Alignment(fasta_file) 24 | sequences = alignment.get_encoded_sequence_tensor(tensor_tree.taxon_set) 25 | treeflow_func = lambda blens: Sample( 26 | LeafCTMC( 27 | get_transition_probabilities_tree( 28 | tensor_tree.with_branch_lengths(blens), subst_model, **hky_params 29 | ), 30 | hky_params["frequencies"], 31 | ), 32 | sample_shape=alignment.site_count, 33 | ).log_prob(sequences) 34 | 35 | beagle_func, _ = beagle_likelihood( 36 | fasta_file, subst_model, newick_file=newick_file, dated=dated, **hky_params 37 | ) 38 | 39 | blens = tensor_tree.branch_lengths 40 | with tf.GradientTape() as tf_t: 41 | tf_t.watch(blens) 42 | tf_ll = treeflow_func(blens) 43 | tf_gradient = tf_t.gradient(tf_ll, blens) 44 | 45 | with tf.GradientTape() as bito_t: 46 | bito_t.watch(blens) 47 | bito_ll = beagle_func(blens) 48 | bito_gradient = bito_t.gradient(bito_ll, blens) 49 | 50 | assert_allclose(tf_ll.numpy(), bito_ll.numpy()) 51 | assert_allclose(tf_gradient.numpy(), bito_gradient.numpy()) 52 | -------------------------------------------------------------------------------- /test/acceleration/bito/test_bito_instance.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from treeflow.tree.io import parse_newick 4 | from treeflow.traversal.anchor_heights import get_anchor_heights 5 | from numpy.testing import assert_allclose, assert_equal 6 | 7 | 8 | # TODO: Test taxon name order 9 | 10 | def test_get_tree_info(newick_file_dated): 11 | from treeflow.acceleration.bito.instance import get_instance, get_tree_info 12 | 13 | newick_file, dated = newick_file_dated 14 | treeflow_tree = parse_newick(newick_file) 15 | treeflow_node_bounds = get_anchor_heights(treeflow_tree) 16 | 17 | inst = get_instance(newick_file, dated=dated) 18 | bito_tree, bito_node_bounds = get_tree_info(inst) 19 | 20 | assert_equal( 21 | treeflow_tree.topology.parent_indices, 22 | bito_tree.topology.parent_indices, 23 | ) 24 | assert_allclose(treeflow_tree.heights, bito_tree.heights, atol=1e-12) 25 | assert_allclose(treeflow_node_bounds, bito_node_bounds, atol=1e-12) 26 | -------------------------------------------------------------------------------- /test/acceleration/bito/test_bito_ratio_transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import pytest 4 | from treeflow.bijectors.node_height_ratio_bijector import NodeHeightRatioBijector 5 | from treeflow.tree.io import parse_newick 6 | from treeflow.traversal.anchor_heights import get_anchor_heights 7 | from treeflow.tree.rooted.tensorflow_rooted_tree import convert_tree_to_tensor 8 | from functools import partial 9 | from numpy.testing import assert_allclose 10 | 11 | 12 | def get_bito_forward_func(newick_file, dated): 13 | from treeflow.acceleration.bito.instance import get_instance, get_tree_info 14 | from treeflow.acceleration.bito.ratio_transform import ( 15 | ratios_to_node_heights as bito_ratios_to_node_heights, 16 | ) 17 | 18 | inst = get_instance(newick_file, dated=dated) 19 | tree, anchor_heights = 
get_tree_info(inst) 20 | def forward_func(ratios): 21 | return bito_ratios_to_node_heights(inst, anchor_heights, ratios) 22 | return forward_func 23 | 24 | 25 | def get_treeflow_forward_func(numpy_tree, tensor_constant): 26 | anchor_heights = tensor_constant(get_anchor_heights(numpy_tree)) 27 | return NodeHeightRatioBijector( 28 | convert_tree_to_tensor(numpy_tree).topology, anchor_heights 29 | ).forward 30 | 31 | 32 | def get_ratios(taxon_count): 33 | return (np.arange(taxon_count - 1) + 1.0) / taxon_count 34 | 35 | 36 | def get_test_values(forward_func, ratios): 37 | with tf.GradientTape() as t: 38 | t.watch(ratios) 39 | heights = forward_func(ratios) 40 | res = tf.reduce_sum(heights ** 2) 41 | 42 | grad = t.gradient(res, ratios) 43 | return heights, grad 44 | 45 | 46 | def test_bito_ratio_transform_forward(newick_file_dated, tensor_constant): 47 | newick_file, dated = newick_file_dated 48 | numpy_tree = parse_newick(newick_file) 49 | bito_forward_func = get_bito_forward_func(newick_file, dated) 50 | treeflow_forward_func = get_treeflow_forward_func(numpy_tree, tensor_constant) 51 | ratios = tensor_constant(get_ratios(numpy_tree.taxon_count)) 52 | 53 | treeflow_heights, treeflow_grad = get_test_values(treeflow_forward_func, ratios) 54 | bito_heights, bito_grad = get_test_values(bito_forward_func, ratios) 55 | 56 | assert_allclose(bito_heights.numpy(), treeflow_heights.numpy()) 57 | assert_allclose(bito_grad.numpy(), treeflow_grad.numpy()) 58 | -------------------------------------------------------------------------------- /test/bijectors/test_highway_flow_node_bijector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import attr 4 | import tensorflow as tf 5 | from treeflow.tree.rooted.tensorflow_rooted_tree import TensorflowRootedTree 6 | from treeflow.model.approximation.cascading_flows import ( 7 | DEFAULT_ACTIVATION_FUNCTIONS, 8 | get_trainable_highway_flow_parameters, 9 | ) 10 
| from treeflow.bijectors.highway_flow import ( 11 | HighwayFlowParameters, 12 | HIGHWAY_FLOW_PARAMETER_EVENT_NDIMS, 13 | ) 14 | from treeflow.bijectors.highway_flow_node_bijector import HighwayFlowNodeBijector 15 | 16 | 17 | def test_highway_flow_node_bijector_shapes(hello_tensor_tree: TensorflowRootedTree): 18 | tree = hello_tensor_tree 19 | dtype = tree.node_heights.dtype 20 | activation_functions = DEFAULT_ACTIVATION_FUNCTIONS 21 | 22 | batch_shape = tf.stack([len(activation_functions), tree.taxon_count - 1]) 23 | 24 | parameters = HighwayFlowParameters( 25 | **attr.asdict( 26 | get_trainable_highway_flow_parameters( 27 | 2, batch_shape, dtype=dtype, defer=False 28 | ) 29 | ) 30 | ) 31 | 32 | flow_bijector = HighwayFlowNodeBijector( 33 | tree.topology, 34 | parameters, 35 | (), 36 | flow_parameter_event_ndims=HIGHWAY_FLOW_PARAMETER_EVENT_NDIMS, 37 | activation_functions=activation_functions, 38 | ) 39 | sample_shape = (4,) 40 | tree_shape = (tree.taxon_count,) 41 | event_shape = (1,) 42 | shape = sample_shape + tree_shape + event_shape 43 | base_value = tf.zeros(shape, dtype=dtype) 44 | forward = flow_bijector.forward(base_value) 45 | inverse = flow_bijector.inverse(forward) 46 | assert inverse.numpy().shape == shape 47 | -------------------------------------------------------------------------------- /test/bijectors/test_preorder_node_bijector.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from treeflow import DEFAULT_FLOAT_DTYPE_TF 3 | import numpy as np 4 | from numpy.testing import assert_allclose 5 | 6 | from tensorflow_probability.python.bijectors import Chain, Shift, Scale 7 | from treeflow_test_helpers.ratio_helpers import ( 8 | topology_from_ratio_test_data, 9 | RatioTestData, 10 | ) 11 | from treeflow.bijectors.preorder_node_bijector import PreorderNodeBijector 12 | from treeflow.bijectors.node_height_ratio_bijector import NodeHeightRatioBijector 13 | 14 | 15 | def 
ratio_transform_forward_mapping(parent_height, anchor_height): 16 | return Chain([Shift(anchor_height), Scale(parent_height - anchor_height)]) 17 | 18 | 19 | def ratio_transform_forward_root_mapping(anchor_height): 20 | return Shift(anchor_height) 21 | 22 | 23 | def get_bijector(ratio_test_data: RatioTestData) -> PreorderNodeBijector: 24 | topology = topology_from_ratio_test_data(ratio_test_data) 25 | bijector = PreorderNodeBijector( 26 | topology, 27 | ratio_test_data.anchor_heights, 28 | ratio_transform_forward_mapping, 29 | ratio_transform_forward_root_mapping, 30 | ) 31 | return bijector 32 | 33 | 34 | def test_preorder_node_bijector_forward(ratio_test_data: RatioTestData): 35 | bijector = get_bijector(ratio_test_data) 36 | res = bijector.forward(ratio_test_data.ratios) 37 | expected = ratio_test_data.heights 38 | assert_allclose(res.numpy(), expected) 39 | 40 | 41 | def test_preorder_node_bijector_forward_log_det_jacobian( 42 | ratio_test_data: RatioTestData, 43 | ): 44 | bijector = get_bijector(ratio_test_data) 45 | res = bijector.forward_log_det_jacobian(ratio_test_data.ratios) 46 | 47 | test_bijector = NodeHeightRatioBijector( 48 | bijector._topology, ratio_test_data.anchor_heights 49 | ) 50 | expected = test_bijector.forward_log_det_jacobian(ratio_test_data.ratios) 51 | 52 | assert_allclose(res.numpy(), expected.numpy()) 53 | 54 | 55 | def test_preorder_node_bijector_inverse(ratio_test_data: RatioTestData): 56 | bijector = get_bijector(ratio_test_data) 57 | res = bijector.inverse(tf.constant(ratio_test_data.heights, DEFAULT_FLOAT_DTYPE_TF)) 58 | assert_allclose(res.numpy(), ratio_test_data.ratios) 59 | -------------------------------------------------------------------------------- /test/cli/test_benchmark.py: -------------------------------------------------------------------------------- 1 | from operator import gt 2 | import pytest 3 | from treeflow.cli.benchmark import treeflow_benchmark 4 | from click.testing import CliRunner 5 | 6 | 7 | 
@pytest.fixture 8 | def benchmark_output_path(tmp_path): 9 | return tmp_path / "treeflow-benchmark.csv" 10 | 11 | 12 | @pytest.mark.parametrize(["use_bito", "gtr"], [(True, False), (False, True)]) 13 | def test_benchmark( 14 | hello_fasta_file, hello_newick_file, benchmark_output_path, use_bito, gtr 15 | ): 16 | runner = CliRunner() 17 | args = [ 18 | "-i", 19 | hello_fasta_file, 20 | "-t", 21 | hello_newick_file, 22 | "-r", 23 | str(1), 24 | "-o", 25 | str(benchmark_output_path), 26 | ] 27 | if use_bito: 28 | args.append("--use-bito") 29 | if gtr: 30 | args.append("--gtr") 31 | res = runner.invoke( 32 | treeflow_benchmark, 33 | args, 34 | catch_exceptions=False, 35 | ) 36 | print(res) 37 | -------------------------------------------------------------------------------- /test/cli/test_ml_cli.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from click.testing import CliRunner 3 | from treeflow.cli.ml import treeflow_ml 4 | 5 | 6 | @pytest.mark.parametrize("include_init_values", [True, False]) 7 | def test_ml_cli( 8 | newick_fasta_file_dated, 9 | include_init_values, 10 | model_file, 11 | samples_output_path, 12 | tree_samples_output_path, 13 | trace_output_path, 14 | ): 15 | import pandas as pd 16 | import dendropy 17 | 18 | newick_file, fasta_file, dated = newick_fasta_file_dated 19 | runner = CliRunner() 20 | args = [ 21 | "-i", 22 | str(fasta_file), 23 | "-t", 24 | str(newick_file), 25 | "-n", 26 | str(10), 27 | "--variables-output", 28 | str(samples_output_path), 29 | "--tree-output", 30 | str(tree_samples_output_path), 31 | "--trace-output", 32 | str(trace_output_path), 33 | ] 34 | if model_file is None: 35 | init_values_string = "clock_rate=0.01" 36 | else: 37 | args = args + ["-m", model_file] 38 | init_values_string = "pop_size=10" 39 | 40 | if include_init_values: 41 | args = args + ["--init-values", init_values_string] 42 | res = runner.invoke( 43 | treeflow_ml, 44 | args, 45 | 
catch_exceptions=False, 46 | ) 47 | assert res.exit_code == 0 48 | print(res.stdout) 49 | samples = pd.read_csv(samples_output_path) 50 | assert samples.shape[0] == 1 51 | 52 | trees = dendropy.TreeList.get(path=tree_samples_output_path, schema="nexus") 53 | assert len(trees) == 1 54 | -------------------------------------------------------------------------------- /test/data/beast-test-case.nwk: -------------------------------------------------------------------------------- 1 | ((((human:0.024003,(chimp:0.010772,bonobo:0.010772):0.013231):0.012035,gorilla:0.036038):0.033087000000000005,orangutan:0.069125):0.030456999999999998,siamang:0.099582); -------------------------------------------------------------------------------- /test/data/hello.fasta: -------------------------------------------------------------------------------- 1 | >mars 2 | CCGAG-AGCAGCAATGGAT-GAGGCATGGCG 3 | >saturn 4 | GCGCGCAGCTGCTGTAGATGGAGGCATGACG 5 | >jupiter 6 | GCGCGCAGCAGCTGTGGATGGAAGGATGACG 7 | -------------------------------------------------------------------------------- /test/data/hello.nwk: -------------------------------------------------------------------------------- 1 | ((mars:0.1,saturn:0.1):0.2,jupiter:0.3); 2 | -------------------------------------------------------------------------------- /test/data/model.yaml: -------------------------------------------------------------------------------- 1 | tree: 2 | coalescent: 3 | pop_size: 4 | lognormal: 5 | loc: 0.0 6 | scale: 2.2 7 | clock: 8 | relaxed_lognormal: 9 | branch_rate_loc: 10 | lognormal: 11 | loc: -7.6 12 | scale: 1.0 13 | branch_rate_scale: 14 | lognormal: 15 | loc: -1.4 16 | scale: 0.45 17 | substitution: 18 | hky: 19 | kappa: 20 | gamma: 21 | concentration: 2.0 22 | rate: 1.0 23 | frequencies: 24 | dirichlet: 25 | concentration: [4.0, 4.0, 4.0, 4.0] 26 | site: 27 | discrete_gamma: 28 | site_gamma_shape: 29 | gamma: 30 | concentration: 3.0 31 | rate: 3.0 32 | category_count: 4 33 | 
-------------------------------------------------------------------------------- /test/data/tree-sim.newick: -------------------------------------------------------------------------------- 1 | ((((((T10_0.9473684210526315:0.06786,(T12_1.1578947368421053:0.08314,(T16_1.5789473684210527:0.08301,T19_1.894736842105263:0.39880):0.42118):0.19524):0.58492,T7_0.631578947368421:0.33699):0.06914,(T11_1.0526315789473684:0.13027,((T17_1.6842105263157894:0.02456,T20_2.0:0.34035):0.04107,T18_1.789473684210526:0.17089):0.69621):0.69691):0.02538,(T13_1.263157894736842:0.26520,T14_1.3684210526315788:0.37046):0.79788):0.91598,T1_0.0:0.71589):2.19241,((((T15_1.4736842105263157:0.75595,T8_0.7368421052631579:0.01910):0.42619,T9_0.8421052631578947:0.55055):0.03637,T4_0.3157894736842105:0.06061):1.82915,((T2_0.10526315789473684:0.02213,T5_0.42105263157894735:0.33792):0.13404,(T3_0.21052631578947367:0.14991,T6_0.5263157894736842:0.46570):0.11153):1.52306):1.33433):0.00000; 2 | -------------------------------------------------------------------------------- /test/data/yule-model.yaml: -------------------------------------------------------------------------------- 1 | tree: 2 | yule: 3 | birth_rate: 4 | lognormal: 5 | loc: 0.0 6 | scale: 2.2 7 | clock: 8 | strict: 9 | clock_rate: 10 | lognormal: 11 | loc: -7.6 12 | scale: 1.0 13 | substitution: 14 | hky: 15 | kappa: 16 | lognormal: 17 | loc: 0.0 18 | scale: 2.0 19 | frequencies: 20 | dirichlet: 21 | concentration: [2.0, 2.0, 2.0, 2.0] 22 | site: 23 | discrete_weibull: 24 | site_weibull_concentration: 25 | gamma: 26 | concentration: 3.0 27 | rate: 3.0 28 | category_count: 4 29 | -------------------------------------------------------------------------------- /test/debug_util/test_minimize_eager.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow_probability.python.optimizer.convergence_criteria.loss_not_decreasing import ( 3 | LossNotDecreasing, 4 | ) 5 | 
import tensorflow as tf 6 | from tensorflow_probability.python.math.minimize import minimize as tfp_minimize 7 | from treeflow.debug.minimize_eager import minimize_eager 8 | from treeflow_test_helpers.optimization_helpers import obj, vars, optimizer_builder 9 | from numpy.testing import assert_allclose 10 | 11 | 12 | def criterion_builder(): 13 | return LossNotDecreasing(atol=1e-3, min_num_steps=5, window_size=3) 14 | 15 | 16 | def test_minimize_eager_matches(): 17 | optimizer = optimizer_builder() 18 | other_optimizer = optimizer_builder() 19 | 20 | _vars = vars() 21 | other_vars = vars("other_") 22 | 23 | _obj = obj(_vars) 24 | other_obj = obj(other_vars) 25 | 26 | trace_fn = lambda x: (x.step, x.loss, x.parameters) 27 | 28 | num_steps = 10 29 | eager_res = minimize_eager(_obj, num_steps, optimizer, trace_fn=trace_fn) 30 | other_res = tfp_minimize(other_obj, num_steps, other_optimizer, trace_fn=trace_fn) 31 | 32 | tf.nest.map_structure(assert_allclose, eager_res, other_res) 33 | 34 | 35 | def test_minimize_eager_convergence(): 36 | convergence_criterion = criterion_builder() 37 | trace_fn = lambda x: (x.loss, x.has_converged, x.convergence_criterion_state) 38 | num_steps = 1000 39 | 40 | optimizer = optimizer_builder() 41 | _vars = vars() 42 | _obj = obj(_vars) 43 | eager_res = minimize_eager( 44 | _obj, 45 | num_steps, 46 | optimizer, 47 | trace_fn=trace_fn, 48 | batch_convergence_reduce_fn=tf.reduce_any, 49 | convergence_criterion=convergence_criterion, 50 | ) 51 | 52 | other_optimizer = optimizer_builder() 53 | other_vars = vars("other_") 54 | other_obj = obj(other_vars) 55 | other_res = tfp_minimize( 56 | other_obj, 57 | num_steps, 58 | other_optimizer, 59 | trace_fn=trace_fn, 60 | return_full_length_trace=False, 61 | batch_convergence_reduce_fn=tf.reduce_any, 62 | convergence_criterion=convergence_criterion, 63 | ) 64 | 65 | tf.nest.map_structure(assert_allclose, eager_res, other_res) 66 | 
-------------------------------------------------------------------------------- /test/debug_util/test_nonfinite_convergence_criterion.py: -------------------------------------------------------------------------------- 1 | from tensorflow_probability.python.math.minimize import minimize as tfp_minimize 2 | from treeflow_test_helpers.optimization_helpers import obj, vars, optimizer_builder 3 | from treeflow.debug.nonfinite_convergence_criterion import NonfiniteConvergenceCriterion 4 | 5 | # TODO: More comprehensive tests 6 | 7 | 8 | def test_NonfiniteConvergenceCriterion_runs(): # Smoke test: a 10-step tfp minimize run using the criterion must complete without raising; the result is not inspected. 9 | optimizer = optimizer_builder() 10 | convergence_criterion = NonfiniteConvergenceCriterion() 11 | 12 | _vars = vars() 13 | _obj = obj(_vars) 14 | 15 | res = tfp_minimize( 16 | _obj, 17 | 10, 18 | optimizer, 19 | return_full_length_trace=False, 20 | convergence_criterion=convergence_criterion, 21 | ) 22 | -------------------------------------------------------------------------------- /test/distributions/markov_chain/test_linear_gaussian.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow_probability.python.distributions import Independent 3 | from treeflow.distributions.markov_chain.postorder import PostorderNodeMarkovChain 4 | from treeflow.distributions.markov_chain.linear_gaussian import ( 5 | LinearGaussianPostorderNodeMarkovChain, 6 | ) 7 | from treeflow.tree.rooted.tensorflow_rooted_tree import TensorflowRootedTree 8 | from numpy.testing import assert_allclose 9 | 10 | 11 | def test_LinearGaussianPostorderNodeMarkovChain_sample( 12 | hello_tensor_tree: TensorflowRootedTree, 13 | ): # Samples have shape sample_shape + (taxon_count - 1,). 14 | float_dtype = hello_tensor_tree.node_heights.dtype 15 | scale = tf.constant(0.1, dtype=float_dtype) 16 | node_means = tf.random.normal( 17 | tf.expand_dims(hello_tensor_tree.taxon_count - 1, 0), seed=1, dtype=float_dtype 18 | ) 19 | dist = LinearGaussianPostorderNodeMarkovChain( 20 | hello_tensor_tree.topology,
node_means, scale 21 | ) 22 | sample_shape = (4,) 23 | res = dist.sample(sample_shape, seed=2) 24 | assert res.numpy().shape == sample_shape + (hello_tensor_tree.taxon_count - 1,) 25 | 26 | 27 | def test_LinearGaussianPostorderNodeMarkovChain_log_prob( 28 | hello_tensor_tree: TensorflowRootedTree, 29 | ): # log_prob must agree with PostorderNodeMarkovChain.log_prob evaluated on the same distribution and samples. 30 | float_dtype = hello_tensor_tree.node_heights.dtype 31 | scale = tf.constant(0.1, dtype=float_dtype) 32 | node_means = tf.random.normal( 33 | tf.expand_dims(hello_tensor_tree.taxon_count - 1, 0), seed=1, dtype=float_dtype 34 | ) 35 | dist = LinearGaussianPostorderNodeMarkovChain( 36 | hello_tensor_tree.topology, node_means, scale 37 | ) 38 | sample_shape = (4,) 39 | sample = dist.sample(sample_shape, seed=2) 40 | 41 | log_prob = dist.log_prob(sample) 42 | expected_log_prob = PostorderNodeMarkovChain.log_prob(dist, sample) 43 | assert_allclose(log_prob.numpy(), expected_log_prob.numpy()) 44 | 45 | 46 | def test_LinearGaussianPostorderNodeMarkovChain_sample_and_log_prob_batch( 47 | hello_tensor_tree: TensorflowRootedTree, 48 | ): # Independent reinterprets the batch dim as part of the event, so log_prob has shape sample_shape. 49 | float_dtype = hello_tensor_tree.node_heights.dtype 50 | scale = tf.constant(0.1, dtype=float_dtype) 51 | batch_shape = (4,) 52 | node_means = tf.random.normal( 53 | batch_shape + (hello_tensor_tree.taxon_count - 1,), seed=1, dtype=float_dtype 54 | ) 55 | dist = Independent( 56 | LinearGaussianPostorderNodeMarkovChain( 57 | hello_tensor_tree.topology, node_means, scale 58 | ), 59 | reinterpreted_batch_ndims=len(batch_shape), 60 | ) 61 | sample_shape = (5,) 62 | sample = dist.sample(sample_shape, seed=2) 63 | assert sample.numpy().shape == sample_shape + batch_shape + ( 64 | hello_tensor_tree.taxon_count - 1, 65 | ) 66 | 67 | log_prob = dist.log_prob(sample) 68 | assert log_prob.numpy().shape == sample_shape 69 | -------------------------------------------------------------------------------- /test/distributions/markov_chain/test_postorder_node_markov_chain.py:
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow_probability.python.distributions import Normal 4 | from treeflow.tree.rooted.tensorflow_rooted_tree import TensorflowRootedTree 5 | from treeflow.distributions.markov_chain.postorder import PostorderNodeMarkovChain 6 | 7 | 8 | def test_PostorderNodeMarkovChain_sample(hello_tensor_tree: TensorflowRootedTree): # Per-node conditional is Normal(input + sum(children), scale); checks the sample shape. 9 | float_dtype = hello_tensor_tree.node_heights.dtype 10 | scale = tf.constant(0.1, dtype=float_dtype) 11 | node_means = tf.random.normal( 12 | tf.expand_dims(hello_tensor_tree.taxon_count - 1, 0), seed=1, dtype=float_dtype 13 | ) 14 | dist = PostorderNodeMarkovChain( 15 | hello_tensor_tree.topology, 16 | lambda input, children: Normal(input + tf.reduce_sum(children, axis=0), scale), 17 | node_means, 18 | childless_init=tf.zeros((0,), float_dtype), 19 | ) 20 | sample_shape = (4,) 21 | res = dist.sample(sample_shape, seed=2) 22 | assert res.numpy().shape == sample_shape + (hello_tensor_tree.taxon_count - 1,) 23 | 24 | 25 | def test_PostorderNodeMarkovChain_log_prob(hello_tensor_tree: TensorflowRootedTree): # log_prob of the chain's own samples must have shape sample_shape and be finite. 26 | float_dtype = hello_tensor_tree.node_heights.dtype 27 | scale = tf.constant(0.1, dtype=float_dtype) 28 | node_means = tf.random.normal( 29 | tf.expand_dims(hello_tensor_tree.taxon_count - 1, 0), seed=1, dtype=float_dtype 30 | ) 31 | dist = PostorderNodeMarkovChain( 32 | hello_tensor_tree.topology, 33 | lambda input, children: Normal(input + tf.reduce_sum(children, axis=0), scale), 34 | node_means, 35 | childless_init=tf.zeros((0,), float_dtype), 36 | ) 37 | sample_shape = (4,) 38 | samples = dist.sample(sample_shape, seed=2) 39 | res = dist.log_prob(samples) 40 | assert res.numpy().shape == sample_shape 41 | assert all(np.isfinite(res.numpy())) 42 | -------------------------------------------------------------------------------- /test/distributions/test_discretized.py:
-------------------------------------------------------------------------------- 1 | import pytest 2 | import tensorflow as tf 3 | from tensorflow_probability.python.distributions import Normal, Gamma 4 | from treeflow.distributions.discretized import DiscretizedDistribution 5 | from numpy.testing import assert_allclose 6 | 7 | 8 | def test_discretized_log_prob(): # Support points each have mass 1/k; points offset by 1e-3 outside the support have zero mass. 9 | base_dist = Normal(2.0, 2.0) 10 | k = 5 11 | discretized = DiscretizedDistribution(k, base_dist) 12 | quantiles = discretized.support 13 | other_length = 2 14 | first_others = quantiles[:other_length] - 1e-3 15 | last_others = quantiles[-other_length:] + 1e-3 16 | x = tf.concat([first_others, quantiles, last_others], axis=0) 17 | res = discretized.prob(x).numpy() 18 | assert res.shape == (2 * other_length + k,) 19 | mass = 1.0 / k 20 | assert_allclose(res[:other_length], 0.0) 21 | assert_allclose(res[-other_length:], 0.0) 22 | assert_allclose(res[other_length:-other_length], mass) 23 | 24 | 25 | @pytest.mark.parametrize("function_mode", [True, False]) 26 | def test_discretized_sample(function_mode): # Samples (eager and inside tf.function) each have probability 1/k. 27 | sample_shape = (3, 2) 28 | k = 6 29 | mass = 1.0 / k 30 | 31 | def sample_and_prob_func(k): 32 | base_dist = Gamma(2.0, 2.0) 33 | discretized = DiscretizedDistribution(k, base_dist) 34 | sample = discretized.sample(sample_shape, seed=1) 35 | prob = discretized.prob(sample) 36 | return prob 37 | 38 | if function_mode: 39 | sample_and_prob_func = tf.function(sample_and_prob_func) 40 | 41 | res = sample_and_prob_func(k).numpy() 42 | assert res.shape == sample_shape 43 | assert_allclose(res, mass) 44 | 45 | 46 | def test_discretized_sample_and_log_prob_batch(tensor_constant): # Batched Gamma rates: prob has shape (sample_shape, batch_size) and every entry is 1/k. 47 | batch_size = 3 48 | base_rate = tensor_constant(2.0) 49 | rate = base_rate + tf.range(batch_size, dtype=base_rate.dtype) 50 | base_dist = Gamma(tensor_constant(2.0), rate) 51 | k = 6 52 | discretized = DiscretizedDistribution(k, base_dist) 53 | mass = 1.0 / k 54 | sample_shape = 2 55 | sample = discretized.sample(sample_shape,
seed=1) 56 | prob = discretized.prob(sample).numpy() 57 | assert prob.shape == (sample_shape, batch_size) 58 | assert_allclose(prob, mass) 59 | 60 | 61 | def test_discretized_batch_shape(tensor_constant): # batch_shape_tensor matches the batch size of the base Gamma's rate. 62 | batch_size = 3 63 | base_rate = tensor_constant(2.0) 64 | rate = base_rate + tf.range(batch_size, dtype=base_rate.dtype) 65 | base_dist = Gamma(tensor_constant(2.0), rate) 66 | discretized = DiscretizedDistribution(2, base_dist) 67 | batch_shape = discretized.batch_shape_tensor() 68 | assert tuple(batch_shape.numpy()) == (batch_size,) 69 | -------------------------------------------------------------------------------- /test/distributions/tree/birthdeath/test_birth_death_contemporary_sampling.py: -------------------------------------------------------------------------------- 1 | from treeflow.distributions.tree.birthdeath.birth_death_contemporary_sampling import ( 2 | BirthDeathContemporarySampling, 3 | ) 4 | from treeflow.tree.io import parse_newick 5 | from treeflow.tree.rooted.tensorflow_rooted_tree import convert_tree_to_tensor 6 | from numpy.testing import assert_allclose 7 | import tensorflow as tf 8 | from treeflow import DEFAULT_FLOAT_DTYPE_TF 9 | 10 | 11 | def test_BirthDeathContemporarySampling_log_prob(): # Compares log_prob with a reference value taken from BEAST 2. 12 | newick = "((((human:0.024003,(chimp:0.010772,bonobo:0.010772):0.013231):0.012035,gorilla:0.036038):0.033087000000000005,orangutan:0.069125):0.030456999999999998,siamang:0.099582);" 13 | tree = convert_tree_to_tensor(parse_newick(newick)) 14 | birth_diff_rate = tf.constant(1.0, dtype=DEFAULT_FLOAT_DTYPE_TF) 15 | relative_death_rate = tf.constant(0.5, dtype=DEFAULT_FLOAT_DTYPE_TF) 16 | expected = 1.2661341104158121 # From BEAST 2 BirthDeathGerhard08ModelTest 17 | 18 | dist = BirthDeathContemporarySampling( 19 | tree.taxon_count, birth_diff_rate, relative_death_rate 20 | ) 21 | res = dist.log_prob(tree) 22 | 23 | assert_allclose(res.numpy(), expected) 24 | 25 | 26 | def test_BirthDeathContemporarySampling_log_prob_vec(hello_newick_file): # A tree with a leading size-1 dim plus rate vectors: log_prob has the rates' shape. 27 |
tree = convert_tree_to_tensor(parse_newick(hello_newick_file)) 28 | tree_b = tree.with_node_heights(tf.expand_dims(tree.node_heights, 0)) 29 | birth_diff_rate = tf.constant([1.0, 1.2], dtype=DEFAULT_FLOAT_DTYPE_TF) 30 | relative_death_rate = tf.constant([0.5, 0.3], dtype=DEFAULT_FLOAT_DTYPE_TF) 31 | 32 | dist = BirthDeathContemporarySampling( 33 | tree.taxon_count, birth_diff_rate, relative_death_rate 34 | ) 35 | res = dist.log_prob(tree_b) 36 | assert res.numpy().shape == birth_diff_rate.numpy().shape 37 | -------------------------------------------------------------------------------- /test/distributions/tree/birthdeath/test_yule.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from treeflow import DEFAULT_FLOAT_DTYPE_TF 4 | from treeflow.distributions.tree.birthdeath.yule import Yule 5 | from treeflow.tree.io import parse_newick 6 | from treeflow.tree.rooted.numpy_rooted_tree import NumpyRootedTree 7 | from treeflow.tree.rooted.tensorflow_rooted_tree import convert_tree_to_tensor 8 | from numpy.testing import assert_allclose 9 | 10 | 11 | def test_yule(): # Compares Yule.log_prob with the reference formula in calc_alt_yule_log_p. 12 | birth_rate = tf.constant(10.0, dtype=DEFAULT_FLOAT_DTYPE_TF) 13 | numpy_tree = parse_newick("((A:1.0,B:1.0):1.0,C:2.0);") 14 | tree = convert_tree_to_tensor(numpy_tree) 15 | dist = Yule(tree.taxon_count, birth_rate) 16 | res = dist.log_prob(tree) 17 | expected = calc_alt_yule_log_p(numpy_tree, birth_rate.numpy()) 18 | assert_allclose(res.numpy(), expected) 19 | 20 | 21 | # From BEAST 2: test.beast.evolution.speciation.YuleModelTest 22 | def calc_alt_yule_log_p(tree: NumpyRootedTree, birth_rate): # log p = (n - 1) log(birth_rate) - birth_rate * (node_heights[-1] + sum(node_heights)) 23 | n = tree.taxon_count 24 | log_p = (n - 1) * np.log(birth_rate) - birth_rate * tree.node_heights[-1] 25 | for height in tree.node_heights: 26 | log_p -= birth_rate * height 27 | return log_p 28 | --------------------------------------------------------------------------------
/test/distributions/tree/test_rooted_tree_distribution.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from treeflow.distributions.tree.rooted_tree_distribution import RootedTreeDistribution 3 | from treeflow.tree.rooted.tensorflow_rooted_tree import TensorflowRootedTree 4 | from tensorflow_probability.python.internal import reparameterization 5 | import pytest 6 | 7 | 8 | class DumbRootedTreeDistribution(RootedTreeDistribution): # Minimal concrete subclass whose _sample_n returns all-zero tensors. 9 | def __init__(self, taxon_count, name="DumbRootedTreeDistribution"): 10 | super().__init__( 11 | taxon_count=taxon_count, 12 | node_height_reparameterization_type=reparameterization.NOT_REPARAMETERIZED, 13 | sampling_time_reparameterization_type=reparameterization.NOT_REPARAMETERIZED, 14 | name=name, 15 | support_topology_batch_dims=True, 16 | ) 17 | 18 | def _parameter_properties(self, num_classes=None): 19 | return dict() 20 | 21 | def _sample_n( 22 | self, 23 | n, 24 | seed=None, 25 | ): # Zeros of shape [n] + event_shape for every component of the (nested) event structure. 26 | event_shape = self.event_shape 27 | dtype = self.dtype 28 | 29 | shape_func = lambda event_shape: tf.concat([[n], event_shape], axis=0) 30 | 31 | return tf.nest.map_structure( 32 | lambda event_shape, dtype: tf.zeros(shape_func(event_shape), dtype), 33 | event_shape, 34 | dtype, 35 | ) 36 | 37 | 38 | @pytest.mark.parametrize("sample_shape", [(), 1, 3, (3,), (3, 2)]) 39 | def test_rooted_tree_distribution_sample(sample_shape): # Smoke test only: sampling must not raise for scalar, int and tuple sample shapes. 40 | taxon_count = 4 41 | distribution_instance = DumbRootedTreeDistribution(taxon_count) 42 | samples = distribution_instance.sample(sample_shape) 43 | 44 | 45 | def test_rooted_tree_distribution_event_shape_tensor_fully_defined(): # Expects n - 1 node heights, n sampling times and 2n - 2 parent indices. 46 | taxon_count = 4 47 | distribution_instance = DumbRootedTreeDistribution(taxon_count) 48 | event_shape = distribution_instance.event_shape_tensor() 49 | assert tuple(event_shape.node_heights.numpy()) == (taxon_count - 1,) 50 | assert tuple(event_shape.sampling_times.numpy()) == (taxon_count,) 51 | assert
tuple(event_shape.topology.parent_indices.numpy()) == (2 * taxon_count - 2,) 52 | -------------------------------------------------------------------------------- /test/evolution/calibration/test_calibration.py: -------------------------------------------------------------------------------- 1 | from treeflow.evolution.calibration.calibration import MRCACalibration 2 | from tensorflow_probability.python.distributions import Normal 3 | from numpy.testing import assert_allclose 4 | 5 | 6 | def test_MRCACalibration_get_normal_sample_sd(): # The fitted normal must put (1 - prob) / 2 mass below low and the same above high. 7 | high = 6.4 8 | low = 3.2 9 | calibration = MRCACalibration(["a", "b"], (low, high)) 10 | prob = 0.99 11 | alpha = (1 - prob) / 2.0 12 | sd = calibration.get_normal_sd(prob) 13 | loc = calibration.get_normal_mean() 14 | res = Normal(loc, sd).cdf([low, high]).numpy() 15 | assert_allclose(res, [alpha, 1 - alpha], atol=1e-9) 16 | -------------------------------------------------------------------------------- /test/evolution/calibration/test_mrca.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from numpy.testing import assert_allclose 3 | from treeflow.tree.io import parse_newick 4 | from treeflow.evolution.calibration.mrca import get_mrca_index 5 | 6 | 7 | @pytest.mark.parametrize( 8 | ["taxa", "expected_mrca_height"], 9 | [ 10 | ( 11 | [ 12 | "T14_1.3684210526315788", 13 | "T18_1.789473684210526", 14 | "T10_0.9473684210526315", 15 | ], 16 | 1.79992, 17 | ), 18 | (["T6_0.5263157894736842", "T5_0.42105263157894735"], 2.05092), 19 | ], 20 | ) 21 | def test_get_mrca_index(tree_sim_newick_file, taxa, expected_mrca_height): # Expected heights refer to tree-sim.newick parsed with zero-length edges kept. 22 | tree = parse_newick(tree_sim_newick_file, remove_zero_edges=False) 23 | mrca_index = get_mrca_index(tree.topology, taxa) 24 | assert_allclose(tree.heights[mrca_index], expected_mrca_height) 25 | -------------------------------------------------------------------------------- /test/evolution/substitution/nucleotide/test_gtr.py:
-------------------------------------------------------------------------------- 1 | import pytest 2 | from treeflow_test_helpers.substitution_helpers import ( 3 | EigenSubstitutionModelHelper, 4 | ) 5 | from treeflow.evolution.substitution.nucleotide.gtr import GTR 6 | from treeflow.evolution.substitution.nucleotide.hky import HKY 7 | import numpy as np 8 | from numpy.testing import assert_allclose 9 | 10 | 11 | @pytest.fixture 12 | def gtr_params(tensor_constant): 13 | return dict( 14 | frequencies=tensor_constant([0.21, 0.28, 0.27, 0.24]), 15 | rates=tensor_constant([0.2, 0.12, 0.17, 0.09, 0.24, 0.18]), 16 | ) 17 | 18 | 19 | class TestGTR(EigenSubstitutionModelHelper): 20 | ClassUnderTest = GTR 21 | 22 | def _init(self, gtr_params): 23 | self.params = gtr_params 24 | 25 | def test_eigendecomposition(self, gtr_params): 26 | self._init(gtr_params) 27 | super().test_eigendecomposition() 28 | 29 | def test_hky_special_case(self, tensor_constant, hky_params): 30 | kappa = hky_params["kappa"].numpy() 31 | rates = np.ones(6) 32 | rates[1] = kappa 33 | rates[4] = kappa 34 | rates = tensor_constant(rates / np.sum(rates)) 35 | gtr_q_norm = GTR().q_norm(frequencies=hky_params["frequencies"], rates=rates) 36 | hky_q_norm = HKY().q_norm(**hky_params) 37 | assert_allclose(gtr_q_norm.numpy(), hky_q_norm.numpy()) 38 | -------------------------------------------------------------------------------- /test/evolution/substitution/nucleotide/test_hky.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from treeflow.evolution.substitution.nucleotide.hky import ( 3 | HKY, 4 | pack_matrix, 5 | pack_matrix_transposed, 6 | ) 7 | from functools import reduce 8 | import tensorflow as tf 9 | from numpy.testing import assert_allclose 10 | from treeflow_test_helpers.substitution_helpers import ( 11 | EigenSubstitutionModelHelper, 12 | ) 13 | 14 | 15 | @pytest.mark.parametrize("batch_shape", [(1, 2), ()]) 16 | def 
test_pack_matrix(batch_shape): 17 | nrow = 3 18 | ncol = 2 19 | batch_size = reduce(lambda x, y: x * y, batch_shape, 1) 20 | elements = [ 21 | [ 22 | tf.reshape(tf.range(i + j, i + j + batch_size), batch_shape) 23 | for j in range(ncol) 24 | ] 25 | for i in range(nrow) 26 | ] 27 | res = pack_matrix(elements) 28 | for i in range(nrow): 29 | for j in range(ncol): 30 | element = res[..., i, j] 31 | assert element.shape == batch_shape 32 | assert_allclose(elements[i][j], element.numpy()) 33 | 34 | 35 | @pytest.mark.parametrize("batch_shape", [(1, 2), ()]) 36 | def test_pack_matrix_transpose(batch_shape): 37 | nrow = 3 38 | ncol = 2 39 | batch_size = reduce(lambda x, y: x * y, batch_shape, 1) 40 | elements = [ 41 | [ 42 | tf.reshape(tf.range(i + j, i + j + batch_size), batch_shape) 43 | for j in range(ncol) 44 | ] 45 | for i in range(nrow) 46 | ] 47 | res = pack_matrix_transposed(elements) 48 | for i in range(nrow): 49 | for j in range(ncol): 50 | element = res[..., j, i] 51 | assert element.shape == batch_shape 52 | assert_allclose(elements[i][j], element.numpy()) 53 | 54 | 55 | class TestHKY(EigenSubstitutionModelHelper): 56 | ClassUnderTest = HKY 57 | 58 | def _init(self, hky_params): 59 | self.params = hky_params 60 | 61 | def test_eigendecomposition(self, hky_params): 62 | self._init(hky_params) 63 | super().test_eigendecomposition() 64 | 65 | 66 | def test_hky_q_norm_vec(hky_params_vec): 67 | res = HKY().q_norm(**hky_params_vec) 68 | assert tuple(res.shape) == tuple(hky_params_vec["kappa"].shape) + (4, 4) 69 | 70 | 71 | def test_hky_eigen_vec(hky_params_vec): 72 | res = HKY().eigen(**hky_params_vec) 73 | 74 | batch_shape = tuple(hky_params_vec["kappa"].shape) 75 | assert tuple(res.eigenvectors.shape) == batch_shape + (4, 4) 76 | assert tuple(res.inverse_eigenvectors.shape) == batch_shape + (4, 4) 77 | assert tuple(res.eigenvalues.shape) == batch_shape + (4,) 78 | -------------------------------------------------------------------------------- 
/test/evolution/substitution/nucleotide/test_jc.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | import tensorflow as tf 3 | from treeflow.evolution.substitution.nucleotide.jc import JC 4 | from treeflow_test_helpers.substitution_helpers import ( 5 | EigenSubstitutionModelHelper, 6 | ) 7 | 8 | 9 | class TestJC(EigenSubstitutionModelHelper): 10 | ClassUnderTest = JC 11 | params: tp.Mapping[str, tf.Tensor] = dict(frequencies=JC.frequencies()) 12 | -------------------------------------------------------------------------------- /test/evolution/substitution/test_eigendecomposition.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from treeflow.evolution.substitution.eigendecomposition import Eigendecomposition 3 | 4 | 5 | def test_eigendecomposition_add_inner_batch_dims(): 6 | n_states = 5 7 | batch_shape = (3, 2) 8 | eig = Eigendecomposition( 9 | eigenvectors=tf.zeros(batch_shape + (n_states, n_states)), 10 | inverse_eigenvectors=tf.zeros(batch_shape + (n_states, n_states)), 11 | eigenvalues=tf.zeros(batch_shape + (n_states,)), 12 | ) 13 | res = eig.add_inner_batch_dimensions(2) 14 | assert res.eigenvectors.numpy().shape == batch_shape + (1, 1) + (n_states, n_states) 15 | assert res.inverse_eigenvectors.numpy().shape == batch_shape + (1, 1) + ( 16 | n_states, 17 | n_states, 18 | ) 19 | assert res.eigenvalues.numpy().shape == batch_shape + (1, 1) + (n_states,) 20 | -------------------------------------------------------------------------------- /test/evolution/substitution/test_probabilities.py: -------------------------------------------------------------------------------- 1 | from cgitb import reset 2 | import pytest 3 | from numpy.testing import assert_allclose 4 | import tensorflow as tf 5 | from treeflow.evolution.substitution.nucleotide.hky import HKY 6 | from treeflow.evolution.substitution.nucleotide.jc import JC 7 | from 
treeflow.evolution.substitution.probabilities import ( 8 | get_transition_probabilities_eigen, 9 | get_transition_probabilities_tree, 10 | ) 11 | 12 | 13 | _branch_lengths = [0.1, [1.2, 3.2, 0.1], [[0.05, 1.3], [0.8, 0.3]]] 14 | 15 | 16 | @pytest.fixture(params=_branch_lengths) 17 | def branch_lengths(tensor_constant, request): 18 | return tensor_constant(request.param) 19 | 20 | 21 | def test_get_transition_probabilities_eigen_hky_rowsum( 22 | branch_lengths, 23 | hky_params, 24 | ): 25 | eigen_batch = ( 26 | HKY().eigen(**hky_params).add_inner_batch_dimensions(branch_lengths.shape.rank) 27 | ) 28 | 29 | res = get_transition_probabilities_eigen(eigen_batch, branch_lengths) 30 | row_sums = tf.reduce_sum(res, axis=-1) 31 | assert_allclose(1.0, row_sums) 32 | 33 | 34 | def test_get_transition_probabilities_tree_hky_vec( 35 | hky_params_vec, hello_tensor_tree, tensor_constant 36 | ): 37 | rate_categories = tensor_constant([0.1, 0.3, 0.6, 1.0, 1.5]) 38 | category_count = rate_categories.shape[0] 39 | unrooted_tree = hello_tensor_tree.get_unrooted_tree() 40 | res = get_transition_probabilities_tree( 41 | unrooted_tree, 42 | HKY(), 43 | **hky_params_vec, 44 | rate_categories=rate_categories, 45 | batch_rank=-1, 46 | inner_batch_rank=1 47 | ) 48 | assert tuple(res.branch_lengths.shape) == (category_count,) + tuple( 49 | hky_params_vec["kappa"].shape 50 | ) + ( 51 | 4, 52 | 4, 53 | ) 54 | -------------------------------------------------------------------------------- /test/evolution/test_seqio.py: -------------------------------------------------------------------------------- 1 | from treeflow.evolution.seqio import Alignment 2 | 3 | 4 | def test_seqio_parse_fasta(hello_fasta_file): 5 | alignment = Alignment(hello_fasta_file) 6 | expected_keys = {"mars", "saturn", "jupiter"} 7 | expected_len = 31 8 | assert set(alignment._sequence_mapping.keys()) == expected_keys 9 | for key in expected_keys: 10 | assert (len(alignment._sequence_mapping[key])) == expected_len 11 | 
-------------------------------------------------------------------------------- /test/fixtures/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christiaanjs/treeflow/c10dd306c5c54c9d3ddc3ea8cd69524cb9bc41fb/test/fixtures/__init__.py -------------------------------------------------------------------------------- /test/fixtures/cli_fixtures.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | @pytest.fixture 5 | def trace_output_path(tmp_path): 6 | return tmp_path / "trace.pickle" 7 | 8 | 9 | @pytest.fixture 10 | def samples_output_path(tmp_path): 11 | return tmp_path / "approx-samples.csv" 12 | 13 | 14 | @pytest.fixture 15 | def tree_samples_output_path(tmp_path): 16 | return tmp_path / "approx-tree-samples.nexus" 17 | 18 | 19 | @pytest.fixture 20 | def actual_model_file(test_data_dir): 21 | return str(test_data_dir / "model.yaml") 22 | 23 | 24 | @pytest.fixture(params=[None, "model.yaml"]) 25 | def model_file(request, test_data_dir): 26 | filename = request.param 27 | if filename is None: 28 | return None 29 | else: 30 | return str(test_data_dir / filename) 31 | -------------------------------------------------------------------------------- /test/fixtures/data_fixtures.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from treeflow.tree.io import parse_newick 3 | from treeflow.tree.rooted.tensorflow_rooted_tree import convert_tree_to_tensor 4 | from treeflow.evolution.seqio import Alignment 5 | 6 | _HELLO_NEWICK = "hello.nwk" 7 | 8 | 9 | @pytest.fixture 10 | def hello_newick_file(test_data_dir): 11 | return str(test_data_dir / _HELLO_NEWICK) 12 | 13 | 14 | _HELLO_FASTA = "hello.fasta" 15 | 16 | 17 | @pytest.fixture 18 | def hello_fasta_file(test_data_dir): 19 | return str(test_data_dir / _HELLO_FASTA) 20 | 21 | 22 | @pytest.fixture 23 | def 
hello_tensor_tree(hello_newick_file): 24 | numpy_tree = parse_newick(hello_newick_file) 25 | return convert_tree_to_tensor(numpy_tree) 26 | 27 | 28 | @pytest.fixture 29 | def hello_alignment(hello_fasta_file): 30 | return Alignment(hello_fasta_file) 31 | 32 | 33 | _WNV_NEWICK = "wnv.nwk" 34 | 35 | 36 | @pytest.fixture 37 | def wnv_newick_file(test_data_dir): 38 | return str(test_data_dir / _WNV_NEWICK) 39 | 40 | 41 | _WNV_FASTA = "wnv.fasta" 42 | 43 | 44 | @pytest.fixture 45 | def wnv_fasta_file(test_data_dir): 46 | return str(test_data_dir / _WNV_FASTA) 47 | 48 | 49 | @pytest.fixture( 50 | params=[ 51 | ( 52 | _HELLO_NEWICK, 53 | _HELLO_FASTA, 54 | False, 55 | ), 56 | ( 57 | _WNV_NEWICK, 58 | _WNV_FASTA, 59 | True, 60 | ), 61 | ] 62 | ) 63 | def newick_fasta_file_dated(test_data_dir, request): 64 | newick_file, fasta_file, dated = request.param 65 | return str(test_data_dir / newick_file), str(test_data_dir / fasta_file), dated 66 | 67 | 68 | @pytest.fixture 69 | def newick_file_dated(newick_fasta_file_dated): 70 | newick_file, _, dated = newick_fasta_file_dated 71 | return newick_file, dated 72 | 73 | 74 | @pytest.fixture 75 | def tree_sim_newick_file(test_data_dir): 76 | return str(test_data_dir / "tree-sim.newick") 77 | 78 | 79 | @pytest.fixture 80 | def beast_test_case_newick_file(test_data_dir): 81 | return str(test_data_dir / "beast-test-case.nwk") 82 | 83 | 84 | @pytest.fixture 85 | def beast_test_case_tree(beast_test_case_newick_file): 86 | numpy_tree = parse_newick(beast_test_case_newick_file) 87 | return convert_tree_to_tensor(numpy_tree) 88 | 89 | 90 | @pytest.fixture 91 | def beast_test_case_fasta_file(test_data_dir): 92 | return str(test_data_dir / "beast-test-case.fasta") 93 | -------------------------------------------------------------------------------- /test/fixtures/ratio_fixtures.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | from 
treeflow_test_helpers.ratio_helpers import RatioTestData 4 | 5 | sampling_times_flat = [ 6 | np.array(x) for x in [[0.0, 0.0, 0.0, 0.0, 0.0], [0.1, 0.2, 0.0, 0.3, 0.2]] 7 | ] 8 | sampling_times = sampling_times_flat + [np.stack(sampling_times_flat)] 9 | anchor_heights_flat = [ 10 | np.array(x) for x in [[0.0, 0.0, 0.0, 0.0], [0.2, 0.3, 0.3, 0.3]] 11 | ] 12 | anchor_heights = anchor_heights_flat + [np.stack(anchor_heights_flat)] 13 | heights_flat = [np.array(x) for x in [[0.2, 0.5, 0.8, 1.6], [0.6, 0.75, 1.2, 1.5]]] 14 | heights = heights_flat + [np.stack(heights_flat)] 15 | ratios_flat = [ 16 | np.array(x) 17 | for x in [ 18 | [0.25, 0.625, 0.5, 1.6], 19 | [0.4, 0.5, 0.75, 1.2], 20 | ] 21 | ] 22 | ratios = ratios_flat + [np.stack(ratios_flat)] 23 | parent_indices = np.array([5, 5, 6, 6, 8, 7, 7, 8]) 24 | taxon_count = 5 25 | node_parent_indices = parent_indices[taxon_count:] - taxon_count 26 | preorder_node_indices = np.array([8, 7, 5, 6]) - 5 27 | 28 | 29 | @pytest.fixture( 30 | params=[ 31 | RatioTestData( 32 | heights=heights_element, 33 | parent_indices=parent_indices, 34 | node_parent_indices=node_parent_indices, 35 | preorder_node_indices=preorder_node_indices, 36 | ratios=ratios_element, 37 | anchor_heights=anchor_heights_element, 38 | sampling_times=sampling_times_element, 39 | ) 40 | for heights_element, ratios_element, anchor_heights_element, sampling_times_element in zip( 41 | heights, ratios, anchor_heights, sampling_times 42 | ) 43 | ] 44 | ) 45 | def ratio_test_data(request): 46 | return request.param 47 | 48 | 49 | @pytest.fixture 50 | def flat_ratio_test_data(): 51 | return RatioTestData( 52 | heights=heights[0], 53 | parent_indices=parent_indices, 54 | node_parent_indices=node_parent_indices, 55 | preorder_node_indices=preorder_node_indices, 56 | ratios=ratios[0], 57 | anchor_heights=anchor_heights[0], 58 | sampling_times=sampling_times[0], 59 | ) 60 | -------------------------------------------------------------------------------- 
/test/fixtures/substitution_fixtures.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import pytest 3 | import tensorflow as tf 4 | 5 | 6 | @pytest.fixture 7 | def hky_params(tensor_constant): 8 | return dict( 9 | frequencies=tensor_constant([0.23, 0.27, 0.24, 0.26]), 10 | kappa=tensor_constant(2.0), 11 | ) 12 | 13 | 14 | @pytest.fixture 15 | def hky_params_vec(hky_params, hello_tensor_tree): 16 | kappa = hky_params["kappa"] 17 | kappa_vec = kappa + tf.cast( 18 | tf.range(hello_tensor_tree.branch_lengths.shape), kappa.dtype 19 | ) 20 | frequencies = hky_params["frequencies"] 21 | frequencies_b = tf.broadcast_to(frequencies, kappa_vec.shape + frequencies.shape) 22 | return dict(frequencies=frequencies_b, kappa=kappa_vec) 23 | 24 | 25 | @pytest.fixture 26 | def hello_hky_log_likelihood(): 27 | return -88.86355638556158 28 | -------------------------------------------------------------------------------- /test/fixtures/tree_fixtures.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | from treeflow_test_helpers.tree_helpers import TreeTestData 4 | 5 | 6 | taxon_count = 3 7 | branch_lengths_flat = [ 8 | np.array(x) for x in [[0.3, 0.4, 1.2, 0.7], [0.9, 0.2, 2.3, 1.4]] 9 | ] 10 | sampling_times_flat = [ 11 | np.array(x) 12 | for x in [ 13 | [ 14 | 0.2, 15 | 0.1, 16 | 0.0, 17 | ], 18 | [ 19 | 0.0, 20 | 0.7, 21 | 0.0, 22 | ], 23 | ] 24 | ] 25 | node_heights_flat = [np.array(x) for x in [[0.5, 1.2], [0.9, 2.3]]] 26 | parent_indices_single = np.array([3, 3, 4, 4]) 27 | branch_lengths_stacked = np.stack(branch_lengths_flat) 28 | node_heights_stacked = np.stack(node_heights_flat) 29 | sampling_times_stacked = np.stack(sampling_times_flat) 30 | parent_indices_stacked = np.stack([parent_indices_single, parent_indices_single]) 31 | 32 | node_heights = node_heights_flat + ([node_heights_stacked] * 2) 33 | sampling_times = 
sampling_times_flat + ([sampling_times_stacked] * 2) 34 | branch_lengths = branch_lengths_flat + ([branch_lengths_stacked] * 2) 35 | parent_indices = ([parent_indices_single] * 3) + [np.stack([parent_indices_single] * 2)] 36 | 37 | 38 | @pytest.fixture( 39 | params=[ 40 | TreeTestData(*args) 41 | for args in zip(parent_indices, node_heights, sampling_times, branch_lengths) 42 | ] 43 | ) 44 | def tree_test_data(request): 45 | return request.param 46 | 47 | 48 | _flat_tree_test_data = TreeTestData( 49 | parent_indices[0], node_heights[0], sampling_times[0], branch_lengths[0] 50 | ) 51 | 52 | 53 | @pytest.fixture() 54 | def flat_tree_test_data(): 55 | return _flat_tree_test_data 56 | -------------------------------------------------------------------------------- /test/helpers/treeflow_test_helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christiaanjs/treeflow/c10dd306c5c54c9d3ddc3ea8cd69524cb9bc41fb/test/helpers/treeflow_test_helpers/__init__.py -------------------------------------------------------------------------------- /test/helpers/treeflow_test_helpers/optimization_helpers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import treeflow 3 | 4 | 5 | def vars(prefix=""): 6 | return dict( 7 | a=tf.Variable( 8 | tf.convert_to_tensor(1.0, dtype=treeflow.DEFAULT_FLOAT_DTYPE_TF), 9 | name=f"{prefix}a", 10 | ), 11 | b=tf.Variable( 12 | tf.convert_to_tensor([1.2, 3.2], dtype=treeflow.DEFAULT_FLOAT_DTYPE_TF), 13 | name=f"{prefix}b", 14 | ), 15 | ) 16 | 17 | 18 | obs = tf.convert_to_tensor([0.8, 1.2], dtype=treeflow.DEFAULT_FLOAT_DTYPE_TF) 19 | 20 | 21 | def obj(vars): 22 | return lambda: tf.math.square(vars["a"]) + tf.math.reduce_sum( 23 | tf.math.square(vars["a"] - vars["b"]) + tf.math.square(vars["b"] - obs) 24 | ) 25 | 26 | 27 | def optimizer_builder(): 28 | return tf.optimizers.SGD(learning_rate=1e-2) 29 | 30 | 
31 | __all__ = [vars.__name__, obj.__name__, optimizer_builder.__name__] 32 | -------------------------------------------------------------------------------- /test/helpers/treeflow_test_helpers/ratio_helpers.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import numpy as np 3 | from treeflow.tree.topology.numpy_tree_topology import NumpyTreeTopology 4 | from treeflow.tree.rooted.numpy_rooted_tree import NumpyRootedTree 5 | from treeflow.tree.rooted.tensorflow_rooted_tree import ( 6 | convert_tree_to_tensor, 7 | TensorflowRootedTree, 8 | ) 9 | from treeflow.tree.topology.tensorflow_tree_topology import ( 10 | TensorflowTreeTopology, 11 | numpy_topology_to_tensor, 12 | ) 13 | 14 | RatioTestData = namedtuple( 15 | "RatioTestData", 16 | [ 17 | "heights", 18 | "node_parent_indices", 19 | "parent_indices", 20 | "preorder_node_indices", 21 | "ratios", 22 | "anchor_heights", 23 | "sampling_times", 24 | ], 25 | ) 26 | 27 | 28 | def numpy_topology_from_ratio_test_data( 29 | ratio_test_data: RatioTestData, 30 | ) -> NumpyTreeTopology: 31 | return NumpyTreeTopology(parent_indices=ratio_test_data.parent_indices) 32 | 33 | 34 | def topology_from_ratio_test_data( 35 | ratio_test_data: RatioTestData, 36 | ) -> TensorflowTreeTopology: 37 | return numpy_topology_to_tensor( 38 | numpy_topology_from_ratio_test_data(ratio_test_data) 39 | ) 40 | 41 | 42 | def numpy_tree_from_ratio_test_data(ratio_test_data: RatioTestData) -> NumpyRootedTree: 43 | return NumpyRootedTree( 44 | sampling_times=ratio_test_data.sampling_times, 45 | node_heights=ratio_test_data.heights, 46 | topology=numpy_topology_from_ratio_test_data(ratio_test_data), 47 | ) 48 | 49 | 50 | def tree_from_ratio_test_data(ratio_test_data: RatioTestData) -> TensorflowRootedTree: 51 | return convert_tree_to_tensor( 52 | numpy_tree_from_ratio_test_data(ratio_test_data), 53 | ) 54 | 55 | 56 | __all__ = [ 57 | RatioTestData.__name__, 58 | 
numpy_topology_from_ratio_test_data.__name__, 59 | topology_from_ratio_test_data.__name__, 60 | numpy_tree_from_ratio_test_data.__name__, 61 | ] 62 | -------------------------------------------------------------------------------- /test/helpers/treeflow_test_helpers/substitution_helpers.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import typing as tp 3 | import tensorflow as tf 4 | from treeflow.evolution.substitution.base_substitution_model import ( 5 | EigendecompositionSubstitutionModel, 6 | ) 7 | from numpy.testing import assert_allclose 8 | 9 | 10 | class SubstitutionModelHelper: 11 | 12 | ClassUnderTest: tp.Type[EigendecompositionSubstitutionModel] 13 | params: tp.Mapping[str, tf.Tensor] 14 | 15 | # TODO 16 | 17 | 18 | class EigenSubstitutionModelHelper(SubstitutionModelHelper): 19 | def test_eigendecomposition(self): 20 | model = self.ClassUnderTest() 21 | res = model.eigen(**self.params) 22 | assert res.eigenvalues.shape == (4,) 23 | assert res.eigenvectors.shape == (4, 4) 24 | assert res.inverse_eigenvectors.shape == (4, 4) 25 | 26 | q = model.q_norm(**self.params) 27 | q_res = ( 28 | res.eigenvectors 29 | @ tf.linalg.diag(res.eigenvalues) 30 | @ res.inverse_eigenvectors 31 | ) 32 | assert_allclose(q.numpy(), q_res.numpy()) 33 | -------------------------------------------------------------------------------- /test/helpers/treeflow_test_helpers/tree_helpers.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from treeflow.tree.rooted.numpy_rooted_tree import NumpyRootedTree 3 | 4 | TreeTestData = namedtuple( 5 | "TreeTestData", 6 | ["parent_indices", "node_heights", "sampling_times", "branch_lengths"], 7 | ) 8 | 9 | from treeflow.tree.rooted.tensorflow_rooted_tree import ( 10 | TensorflowRootedTree, 11 | convert_tree_to_tensor, 12 | ) 13 | 14 | 15 | def data_to_tensor_tree(tree_test_data: 
TreeTestData) -> TensorflowRootedTree: 16 | numpy_tree = NumpyRootedTree( 17 | node_heights=tree_test_data.node_heights, 18 | sampling_times=tree_test_data.sampling_times, 19 | parent_indices=tree_test_data.parent_indices, 20 | ) 21 | tf_tree = convert_tree_to_tensor(numpy_tree) 22 | return tf_tree 23 | -------------------------------------------------------------------------------- /test/model/approximation/test_cascading_flows.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import tensorflow as tf 4 | from treeflow.tree.rooted.tensorflow_rooted_tree import TensorflowRootedTree 5 | from treeflow.model.approximation.cascading_flows import ( 6 | get_cascading_flows_tree_approximation, 7 | ) 8 | 9 | 10 | @pytest.mark.parametrize("function_mode", [True, False]) 11 | def test_get_cascading_flows_tree_approximation( 12 | hello_tensor_tree: TensorflowRootedTree, function_mode: bool 13 | ): 14 | approx = get_cascading_flows_tree_approximation(hello_tensor_tree) 15 | trainable_variables = approx.trainable_variables 16 | sample_shape = (4,) 17 | 18 | def test_func(): 19 | with tf.GradientTape() as t: 20 | for variable in trainable_variables: 21 | t.watch(variable) 22 | sample = approx.sample(sample_shape) 23 | log_prob = approx.log_prob(sample) 24 | grads = t.gradient(log_prob, trainable_variables) 25 | return sample, log_prob, grads 26 | 27 | if function_mode: 28 | test_func = tf.function(test_func) 29 | 30 | sample, log_prob, grads = test_func() 31 | assert isinstance(sample, TensorflowRootedTree) 32 | assert log_prob.numpy().shape == sample_shape 33 | assert all(np.isfinite(log_prob.numpy())) 34 | for grad, variable in zip(grads, trainable_variables): 35 | assert grad is not None, f"Must have gradient wrt {variable.name}" 36 | -------------------------------------------------------------------------------- /test/model/test_ml.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | import tensorflow as tf 4 | import tensorflow_probability.python.distributions as tfd 5 | from treeflow import DEFAULT_FLOAT_DTYPE_TF 6 | from treeflow.tree.rooted.tensorflow_rooted_tree import convert_tree_to_tensor 7 | from treeflow.tree.io import parse_newick 8 | from treeflow.distributions.tree.birthdeath.yule import Yule 9 | from treeflow.model.ml import fit_fixed_topology_maximum_likelihood_sgd 10 | 11 | 12 | _constant = lambda x: tf.constant(x, dtype=DEFAULT_FLOAT_DTYPE_TF) 13 | 14 | 15 | @pytest.mark.parametrize("with_init", [True, False]) 16 | def test_ml_tree_yule(tensor_constant, hello_newick_file, with_init): 17 | tree = convert_tree_to_tensor(parse_newick(hello_newick_file)) 18 | tree_name = "tree_dist_name" 19 | model = tfd.JointDistributionNamed( 20 | { 21 | "rates": tfd.Sample( 22 | tfd.LogNormal(_constant(0.0), _constant(1.0)), tree.branch_lengths.shape 23 | ), 24 | "birth_rate": tfd.LogNormal(_constant(1.0), _constant(1.5)), 25 | tree_name: lambda birth_rate: Yule( 26 | tree.taxon_count, birth_rate, name=tree_name 27 | ), 28 | "a": lambda tree_dist_name, rates: tfd.Normal( 29 | tf.reduce_sum(tree_dist_name.branch_lengths * rates, axis=-1), 30 | _constant(1.0), 31 | ), 32 | } 33 | ) 34 | obs = _constant(10.0) 35 | pinned = model.experimental_pin(a=obs) 36 | 37 | if with_init: 38 | init = dict(tree=tree, birth_rate=_constant(2.0)) 39 | else: 40 | init = None 41 | 42 | res, trace, bijector = fit_fixed_topology_maximum_likelihood_sgd( 43 | pinned, topologies={tree_name: tree.topology}, num_steps=30, init=init 44 | ) 45 | assert all(np.isfinite(trace.log_likelihood.numpy())) 46 | variable_sample = pinned.sample_unpinned() 47 | tf.nest.assert_same_structure(res, variable_sample) 48 | tf.nest.assert_same_structure(trace.parameters, variable_sample) 49 | 
-------------------------------------------------------------------------------- /test/tf_util/test_vectorize.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import attr 3 | import tensorflow as tf 4 | from treeflow.tf_util.vectorize import broadcast_structure, vectorize_over_batch_dims 5 | from numpy.testing import assert_allclose 6 | 7 | 8 | @attr.attrs(auto_attribs=True) 9 | class InnerContainer: 10 | a: tf.Tensor 11 | b: tf.Tensor 12 | 13 | 14 | @attr.attrs(auto_attribs=True) 15 | class Container: 16 | inner: InnerContainer 17 | c: tf.Tensor 18 | 19 | 20 | def test_broadcast_structure(): 21 | a = tf.reshape(tf.range(9), [1, 3, 3]) 22 | b = tf.reshape(tf.range(18), [3, 2, 3]) 23 | c = tf.reshape(tf.range(2), [2, 1]) 24 | 25 | structure = Container(inner=InnerContainer(a=a, b=b), c=c) 26 | event_shape = Container( 27 | inner=InnerContainer( 28 | a=tf.convert_to_tensor([3]), b=tf.convert_to_tensor([2, 3]) 29 | ), 30 | c=tf.convert_to_tensor((), dtype=tf.int32), 31 | ) 32 | 33 | batch_shape = (2, 3) 34 | res = broadcast_structure(structure, event_shape, batch_shape) 35 | assert res.inner.a.shape == (2, 3, 3) 36 | assert res.inner.b.shape == (2, 3, 2, 3) 37 | assert res.c.shape == (2, 3) 38 | 39 | 40 | @pytest.mark.parametrize("function_mode", [False, True]) 41 | @pytest.mark.parametrize("vectorized_map", [False, True]) 42 | def test_vectorize_over_batch_dims_scalar(vectorized_map, function_mode): 43 | 44 | a = tf.reshape(tf.range(36), [3, 2, 3, 2]) 45 | b = tf.reshape(tf.range(36, 72), [3, 2, 2, 3]) 46 | c = tf.reshape(tf.range(72, 90), [3, 2, 3]) 47 | 48 | def func(container): 49 | return tf.reduce_sum( 50 | tf.matmul(container.inner.a, container.inner.b) * container.c 51 | ) 52 | 53 | structure = Container(inner=InnerContainer(a=a, b=b), c=c) 54 | event_shape = tf.nest.map_structure(lambda x: tf.shape(x)[2:], structure) 55 | batch_shape = [3, 2] 56 | 57 | def outer_func(arg): 58 | return 
vectorize_over_batch_dims( 59 | func, 60 | structure, 61 | event_shape, 62 | batch_shape, 63 | vectorized_map=vectorized_map, 64 | fn_output_signature=tf.int32, 65 | ) 66 | 67 | if function_mode: 68 | outer_func = tf.function(outer_func) 69 | 70 | res = outer_func(structure) 71 | 72 | inner_expected = tf.reduce_sum( 73 | tf.expand_dims(a, -1) * tf.expand_dims(b, -3), axis=-2 74 | ) 75 | expected = tf.reduce_sum( 76 | tf.reduce_sum(inner_expected * tf.expand_dims(c, -2), axis=-1), axis=-1 77 | ) 78 | assert_allclose(res.numpy(), expected.numpy()) 79 | 80 | 81 | def test_vectorize_over_batch_dims_structure(): 82 | pass # TODO 83 | -------------------------------------------------------------------------------- /test/traversal/test_get_anchor_heights.py: -------------------------------------------------------------------------------- 1 | from numpy.testing import assert_allclose 2 | from treeflow_test_helpers.ratio_helpers import ( 3 | RatioTestData, 4 | numpy_tree_from_ratio_test_data, 5 | ) 6 | from treeflow.traversal.anchor_heights import ( 7 | get_anchor_heights, 8 | get_anchor_heights_tensor, 9 | ) 10 | from treeflow.tree.rooted.tensorflow_rooted_tree import convert_tree_to_tensor 11 | 12 | 13 | def test_get_anchor_heights(ratio_test_data: RatioTestData): 14 | tree = numpy_tree_from_ratio_test_data(ratio_test_data) 15 | res = get_anchor_heights(tree) 16 | assert_allclose(res, ratio_test_data.anchor_heights) 17 | 18 | 19 | def test_get_anchor_heights_tensor(ratio_test_data: RatioTestData): 20 | tree = convert_tree_to_tensor(numpy_tree_from_ratio_test_data(ratio_test_data)) 21 | res = get_anchor_heights_tensor(tree.topology, tree.sampling_times) 22 | assert_allclose(res.numpy(), ratio_test_data.anchor_heights) 23 | -------------------------------------------------------------------------------- /test/traversal/test_phylo_likelihood.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from numpy.testing import 
assert_allclose 3 | import tensorflow as tf 4 | from treeflow import DEFAULT_FLOAT_DTYPE_TF 5 | from treeflow.evolution.substitution.nucleotide.hky import HKY 6 | from treeflow.evolution.substitution.probabilities import ( 7 | get_transition_probabilities_eigen, 8 | ) 9 | from treeflow.evolution.seqio import Alignment 10 | from treeflow.traversal.phylo_likelihood import phylogenetic_likelihood 11 | from treeflow.tree.rooted.tensorflow_rooted_tree import TensorflowRootedTree 12 | 13 | 14 | @pytest.mark.parametrize("function_mode", [True, False]) 15 | def test_phylo_likelihood_hky_beast( 16 | hello_tensor_tree: TensorflowRootedTree, 17 | hello_alignment: Alignment, 18 | function_mode: bool, 19 | hky_params, 20 | hello_hky_log_likelihood: float, 21 | ): 22 | subst_model = HKY() 23 | eigen = subst_model.eigen(**hky_params) 24 | probs = tf.expand_dims( 25 | get_transition_probabilities_eigen(eigen, hello_tensor_tree.branch_lengths), 0 26 | ) 27 | encoded_sequences = hello_alignment.get_encoded_sequence_tensor( 28 | hello_tensor_tree.taxon_set 29 | ) 30 | if function_mode: 31 | func = tf.function(phylogenetic_likelihood) 32 | else: 33 | func = phylogenetic_likelihood 34 | site_partials = func( 35 | encoded_sequences, 36 | probs, 37 | hky_params["frequencies"], 38 | hello_tensor_tree.topology.postorder_node_indices, 39 | hello_tensor_tree.topology.node_child_indices, 40 | batch_shape=tf.shape(encoded_sequences)[:1], 41 | ) 42 | res = tf.reduce_sum(tf.math.log(site_partials)) 43 | expected = hello_hky_log_likelihood 44 | assert_allclose(res.numpy(), expected) 45 | -------------------------------------------------------------------------------- /test/traversal/test_preorder.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from numpy.testing import assert_allclose 3 | import tensorflow as tf 4 | from treeflow import DEFAULT_FLOAT_DTYPE_TF 5 | from treeflow_test_helpers.ratio_helpers import ( 6 | 
topology_from_ratio_test_data, 7 | RatioTestData, 8 | ) 9 | from treeflow.traversal.preorder import preorder_traversal 10 | from treeflow.traversal.ratio_transform import move_outside_axis_to_inside 11 | from tensorflow_probability.python.internal import distribution_util 12 | 13 | 14 | def c(x): 15 | return tf.constant(x, dtype=DEFAULT_FLOAT_DTYPE_TF) 16 | 17 | 18 | def move_inside_axis_to_outside(x): 19 | return distribution_util.move_dimension(x, -1, 0) 20 | 21 | 22 | def ratios_to_node_heights_traversal(topology, ratios, anchor_heights): 23 | input = ( 24 | move_inside_axis_to_outside(ratios), 25 | move_inside_axis_to_outside(anchor_heights), 26 | ) 27 | 28 | def mapping(parent_height, input): 29 | ratio, anchor_height = input 30 | return (parent_height - anchor_height) * ratio + anchor_height 31 | 32 | init = input[0][-1] + input[1][-1] 33 | 34 | traversal_res = preorder_traversal(topology, mapping, input, init) 35 | return move_outside_axis_to_inside(traversal_res) 36 | 37 | 38 | @pytest.mark.parametrize("function_mode", [True, False]) 39 | def test_preorder_traversal_ratio_transform( 40 | ratio_test_data: RatioTestData, function_mode: bool 41 | ): 42 | topology = topology_from_ratio_test_data(ratio_test_data) 43 | ratios = c(ratio_test_data.ratios) 44 | anchor_heights = c(ratio_test_data.anchor_heights) 45 | 46 | if function_mode: 47 | func = tf.function(ratios_to_node_heights_traversal) 48 | else: 49 | func = ratios_to_node_heights_traversal 50 | res = func(topology, ratios, anchor_heights) 51 | assert_allclose(res.numpy(), ratio_test_data.heights) 52 | -------------------------------------------------------------------------------- /test/traversal/test_ratio_transform.py: -------------------------------------------------------------------------------- 1 | from numpy.testing import assert_allclose 2 | from treeflow.traversal.ratio_transform import ratios_to_node_heights 3 | import tensorflow as tf 4 | 5 | 6 | def 
test_ratios_to_node_heights(ratio_test_data): 7 | res = ratios_to_node_heights( 8 | tf.constant(ratio_test_data.preorder_node_indices, dtype=tf.int32), 9 | tf.constant(ratio_test_data.node_parent_indices, dtype=tf.int32), 10 | ratio_test_data.ratios, 11 | ratio_test_data.anchor_heights, 12 | ) 13 | 14 | assert_allclose(res, ratio_test_data.heights) 15 | -------------------------------------------------------------------------------- /test/tree/rooted/test_numpy_rooted_tree.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from treeflow_test_helpers.tree_helpers import TreeTestData 4 | from treeflow.tree.topology.numpy_tree_topology import NumpyTreeTopology 5 | from treeflow.tree.rooted.numpy_rooted_tree import NumpyRootedTree 6 | from numpy.testing import assert_allclose 7 | 8 | 9 | def test_numpy_tree_get_branch_lengths(tree_test_data: TreeTestData): 10 | expected_branch_lengths = tree_test_data.branch_lengths 11 | tree = NumpyRootedTree( 12 | sampling_times=tree_test_data.sampling_times, 13 | node_heights=tree_test_data.node_heights, 14 | parent_indices=tree_test_data.parent_indices, 15 | ) 16 | branch_lengths = tree.branch_lengths 17 | assert_allclose(branch_lengths, expected_branch_lengths) 18 | -------------------------------------------------------------------------------- /test/tree/rooted/test_tensorflow_rooted_tree.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | from treeflow.tree.rooted.numpy_rooted_tree import NumpyRootedTree 4 | from treeflow.tree.rooted.tensorflow_rooted_tree import ( 5 | TensorflowRootedTree, 6 | convert_tree_to_tensor, 7 | ) 8 | import tensorflow as tf 9 | from numpy.testing import assert_allclose 10 | from treeflow_test_helpers.tree_helpers import TreeTestData, data_to_tensor_tree 11 | 12 | 13 | def test_TensorflowRootedTree_from_numpy(tree_test_data: TreeTestData): 14 | 
tf_tree = data_to_tensor_tree(tree_test_data) 15 | assert_allclose(tf_tree.sampling_times.numpy(), tree_test_data.sampling_times) 16 | assert_allclose(tf_tree.node_heights.numpy(), tree_test_data.node_heights) 17 | assert_allclose( 18 | tf_tree.topology.parent_indices.numpy(), tree_test_data.parent_indices 19 | ) 20 | 21 | 22 | def test_TensorflowRootedTree_heights(tree_test_data: TreeTestData): 23 | tf_tree = data_to_tensor_tree(tree_test_data) 24 | heights_res = tf_tree.heights 25 | expected_heights = np.concatenate( 26 | (tree_test_data.sampling_times, tree_test_data.node_heights), axis=-1 27 | ) 28 | assert_allclose(heights_res.numpy(), expected_heights) 29 | 30 | 31 | @pytest.mark.parametrize("function_mode", [True, False]) 32 | def test_TensorflowRootedTree_get_branch_lengths( 33 | function_mode, tree_test_data: TreeTestData 34 | ): 35 | """Also tests composite tensor functionality""" 36 | tf_tree = data_to_tensor_tree(tree_test_data) 37 | 38 | def blen_func(tree: TensorflowRootedTree): 39 | return tree.branch_lengths 40 | 41 | if function_mode: 42 | blen_func = tf.function(blen_func) 43 | 44 | blen_result = blen_func(tf_tree) 45 | assert_allclose(tree_test_data.branch_lengths, blen_result.numpy()) 46 | -------------------------------------------------------------------------------- /test/tree/test_taxon_set.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | from treeflow.tree.taxon_set import DictTaxonSet, TaxonSet, TupleTaxonSet 3 | 4 | taxa = {"a", "b", "c"} 5 | 6 | 7 | def taxon_set_checks(TaxonSetClass: tp.Callable[[tp.Iterable[str]], TaxonSet]): 8 | taxon_set = TaxonSetClass(taxa) 9 | assert set(taxon_set) == taxa 10 | assert "a" in taxon_set 11 | assert not ("d" in taxon_set) 12 | assert len(taxon_set) == 3 13 | 14 | 15 | def test_dict_taxon_set(): 16 | taxon_set_checks(DictTaxonSet) 17 | 18 | 19 | def test_tuple_taxon_set(): 20 | taxon_set_checks(TupleTaxonSet) 21 | 22 | 23 | def 
test_taxon_set_conversion(): 24 | conversion_taxon_set = lambda taxa: TupleTaxonSet(DictTaxonSet(taxa)) 25 | taxon_set_checks(conversion_taxon_set) 26 | -------------------------------------------------------------------------------- /test/tree/topology/test_numpy_topology_operations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from treeflow.tree.topology.numpy_topology_operations import ( 4 | get_child_indices, 5 | get_preorder_indices, 6 | ) 7 | from numpy.testing import assert_equal 8 | 9 | flat_parent_indices = [np.array(x) for x in [[4, 4, 5, 5, 6, 6], [4, 4, 5, 6, 5, 6]]] 10 | taxon_count = (flat_parent_indices[0].shape[-1] + 2) // 2 11 | leaf_child_indices = [[-1, -1]] * taxon_count 12 | flat_child_indices = [ 13 | np.array(x) 14 | for x in [ 15 | leaf_child_indices + [[0, 1], [2, 3], [4, 5]], 16 | leaf_child_indices + [[0, 1], [2, 4], [3, 5]], 17 | ] 18 | ] 19 | 20 | parent_indices = flat_parent_indices + [np.stack(flat_parent_indices)] 21 | child_indices = flat_child_indices + [np.stack(flat_child_indices)] 22 | 23 | 24 | @pytest.mark.parametrize( 25 | ["parent_indices", "expected_child_indices"], 26 | zip(parent_indices, child_indices), 27 | ) 28 | def test_get_child_indices( 29 | parent_indices: np.ndarray, expected_child_indices: np.ndarray 30 | ): 31 | child_indices = get_child_indices(parent_indices) 32 | assert_equal(child_indices, expected_child_indices) 33 | 34 | 35 | flat_preorder_indices = [ 36 | np.array(x) for x in [[6, 4, 0, 1, 5, 2, 3], [6, 3, 5, 2, 4, 0, 1]] 37 | ] 38 | preorder_indices = flat_preorder_indices + [np.stack(flat_preorder_indices)] 39 | 40 | 41 | @pytest.mark.parametrize( 42 | ["child_indices", "expected_preorder_indices"], 43 | zip(child_indices, preorder_indices), 44 | ) 45 | def test_get_preorder_indices( 46 | child_indices: np.ndarray, expected_preorder_indices: np.ndarray 47 | ): 48 | preorder_indices = get_preorder_indices(child_indices) 
49 | assert_equal(preorder_indices, expected_preorder_indices) 50 | -------------------------------------------------------------------------------- /test/tree/topology/test_tensorflow_topology.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import tensorflow as tf 3 | from treeflow.tree.topology.tensorflow_tree_topology import ( 4 | TensorflowTreeTopology, 5 | numpy_topology_to_tensor, 6 | ) 7 | from treeflow.tree.topology.numpy_tree_topology import NumpyTreeTopology 8 | from numpy.testing import assert_allclose 9 | import numpy as np 10 | 11 | 12 | @tf.function 13 | def function_tf(topology: TensorflowTreeTopology): 14 | return topology.child_indices[..., topology.taxon_count, :] 15 | 16 | 17 | def function_np(topology: NumpyTreeTopology): 18 | return topology.child_indices[..., topology.taxon_count, :] 19 | 20 | 21 | def test_TensorflowTreeTopology_function_arg(flat_tree_test_data): 22 | numpy_topology = NumpyTreeTopology( 23 | parent_indices=flat_tree_test_data.parent_indices 24 | ) 25 | tf_topology = numpy_topology_to_tensor(numpy_topology) 26 | res = function_tf(tf_topology) 27 | expected = function_np(numpy_topology) 28 | assert_allclose(res.numpy(), expected) 29 | 30 | 31 | def test_TensorflowTreeTopology_nest_map(flat_tree_test_data): 32 | numpy_topology = NumpyTreeTopology( 33 | parent_indices=flat_tree_test_data.parent_indices 34 | ) 35 | tf_topology = numpy_topology_to_tensor(numpy_topology) 36 | res = tf.nest.map_structure(tf.shape, tf_topology) 37 | 38 | 39 | def test_TensorflowTreeTopology_get_prefer_static_rank(flat_tree_test_data): 40 | numpy_topology = NumpyTreeTopology( 41 | parent_indices=flat_tree_test_data.parent_indices 42 | ) 43 | tf_topology = numpy_topology_to_tensor(numpy_topology) 44 | rank = tf_topology.get_prefer_static_rank() 45 | assert isinstance(rank.parent_indices, np.ndarray) 46 | assert isinstance(rank.preorder_indices, np.ndarray) 47 | assert 
isinstance(rank.child_indices, np.ndarray) 48 | assert rank.parent_indices == 1 49 | assert rank.preorder_indices == 1 50 | assert rank.child_indices == 2 51 | 52 | 53 | @pytest.mark.parametrize( 54 | ["rank", "expected"], 55 | [ 56 | ( 57 | TensorflowTreeTopology( 58 | parent_indices=1, preorder_indices=1, child_indices=2 59 | ), 60 | False, 61 | ), # Static shape, no batch 62 | ( 63 | TensorflowTreeTopology( 64 | parent_indices=1, preorder_indices=1, child_indices=4 65 | ), 66 | True, 67 | ), # Static shape, with batch 68 | ], 69 | ) # TODO: Tests for dynamic shape? 70 | def test_TensorflowTreeTopology_has_rank_to_has_batch_dimensions_static(rank, expected): 71 | res = TensorflowTreeTopology.rank_to_has_batch_dimensions(rank) 72 | assert res == expected 73 | -------------------------------------------------------------------------------- /test/vi/test_fixed_topology_advi.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import yaml 3 | import tensorflow as tf 4 | from tensorflow_probability.python.vi import fit_surrogate_posterior 5 | from treeflow.model.phylo_model import ( 6 | phylo_model_to_joint_distribution, 7 | PhyloModel, 8 | ) 9 | from treeflow.model.approximation import ( 10 | get_fixed_topology_inverse_autoregressive_flow_approximation, 11 | get_fixed_topology_mean_field_approximation, 12 | get_inverse_autoregressive_flow_approximation, 13 | ) 14 | 15 | approximation_function_and_kwargs = dict( 16 | mean_field=(get_fixed_topology_mean_field_approximation, dict()), 17 | iaf=( 18 | get_fixed_topology_inverse_autoregressive_flow_approximation, 19 | dict(hidden_units_per_layer=5), 20 | ), 21 | ) 22 | 23 | 24 | @pytest.mark.parametrize("approximation", approximation_function_and_kwargs.keys()) 25 | def test_fit_surrogate_posterior_n_samples( 26 | actual_model_file, hello_tensor_tree, hello_alignment, approximation 27 | ): 28 | with open(actual_model_file) as f: 29 | model_dict = yaml.safe_load(f) 30 | 
model = PhyloModel(model_dict) 31 | dist = phylo_model_to_joint_distribution(model, hello_tensor_tree, hello_alignment) 32 | 33 | encoded_sequences = hello_alignment.get_encoded_sequence_tensor( 34 | hello_tensor_tree.taxon_set 35 | ) 36 | pinned = dist.experimental_pin(alignment=encoded_sequences) 37 | approx_func, approx_kwargs = approximation_function_and_kwargs[approximation] 38 | approximation, variables_dict = approx_func( 39 | pinned, topology_pins=dict(tree=hello_tensor_tree.topology), **approx_kwargs 40 | ) 41 | optimizer = tf.optimizers.Adam(learning_rate=1e-2) 42 | num_steps = 11 43 | trace = fit_surrogate_posterior( 44 | pinned.unnormalized_log_prob, 45 | approximation, 46 | optimizer, 47 | num_steps, 48 | sample_size=3, 49 | importance_sample_size=4, 50 | ) 51 | assert tuple(trace.shape) == (num_steps,) 52 | -------------------------------------------------------------------------------- /test/vi/test_progress_bar.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | import pytest 3 | import tensorflow as tf 4 | import tqdm 5 | from tensorflow_probability.python.math.minimize import minimize 6 | from treeflow.vi.progress_bar import make_progress_bar_trace_fn 7 | 8 | 9 | @pytest.mark.parametrize("update_step", [1, 3]) 10 | def test_make_progress_bar_trace_fn(update_step): 11 | x = tf.Variable(0.0) 12 | loss = lambda: tf.square(x) + 2 * x - 2 13 | 14 | trace_fn = lambda mtq: mtq.loss 15 | num_steps = 12 16 | 17 | tqdm_instance: tp.Optional[tqdm.tqdm] = None 18 | 19 | def make_tqdm(total): 20 | nonlocal tqdm_instance 21 | tqdm_instance = tqdm.tqdm(total=total) 22 | return tqdm_instance 23 | 24 | with make_progress_bar_trace_fn( 25 | trace_fn, num_steps, make_tqdm, update_step=update_step 26 | ) as progress_trace_fn: 27 | trace = minimize( 28 | loss, num_steps, tf.optimizers.Adam(), trace_fn=progress_trace_fn 29 | ) 30 | 31 | assert tqdm_instance is not None 32 | assert tqdm_instance.n == 
tqdm_instance.total 33 | assert tqdm_instance.disable 34 | -------------------------------------------------------------------------------- /treeflow/__init__.py: -------------------------------------------------------------------------------- 1 | """TreeFlow: automatic differentiation and probabilistic modelling with phylogenetic trees""" 2 | 3 | __version__ = "0.1" 4 | 5 | import os 6 | 7 | if os.getenv("TREEFLOW_SILENCE_TENSORFLOW", 1) == 1: 8 | print("Silencing TensorFlow...") 9 | from silence_tensorflow import silence_tensorflow 10 | 11 | silence_tensorflow() 12 | 13 | from treeflow.tf_util import ( 14 | DEFAULT_FLOAT_DTYPE_TF, 15 | DEFAULT_FLOAT_DTYPE_NP, 16 | float_constant, 17 | ) 18 | from treeflow.tree import parse_newick, convert_tree_to_tensor, write_tensor_trees 19 | from treeflow.evolution import ( 20 | Alignment, 21 | WeightedAlignment, 22 | AlignmentType, 23 | AlignmentFormat, 24 | ) 25 | from treeflow.model import PhyloModel, phylo_model_to_joint_distribution 26 | 27 | __all__ = [ 28 | "parse_newick", 29 | "convert_tree_to_tensor", 30 | "float_constant", 31 | "Alignment", 32 | "WeightedAlignment", 33 | "AlignmentType", 34 | "AlignmentFormat", 35 | "PhyloModel", 36 | "phylo_model_to_joint_distribution", 37 | "write_tensor_trees", 38 | ] 39 | -------------------------------------------------------------------------------- /treeflow/acceleration/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christiaanjs/treeflow/c10dd306c5c54c9d3ddc3ea8cd69524cb9bc41fb/treeflow/acceleration/__init__.py -------------------------------------------------------------------------------- /treeflow/acceleration/bito/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christiaanjs/treeflow/c10dd306c5c54c9d3ddc3ea8cd69524cb9bc41fb/treeflow/acceleration/bito/__init__.py 
# ---- /treeflow/acceleration/bito/instance.py ----
import numpy as np
import typing as tp
import bito
from treeflow.tree.rooted.numpy_rooted_tree import NumpyRootedTree
from treeflow.tree.topology.numpy_tree_topology import NumpyTreeTopology


def get_instance(newick_file, dated=True, name="treeflow"):
    """
    Create a bito rooted instance and load the trees from a Newick file.

    Parameters
    ----------
    newick_file
        Path of the Newick file passed to ``read_newick_file``.
    dated
        If true, tip dates are parsed from the taxon names; otherwise the
        dates are set to be constant.
    name
        Name given to the bito instance.

    Returns
    -------
    The bito rooted instance.
    """
    inst = bito.rooted_instance(name)
    inst.read_newick_file(newick_file)
    if dated:
        inst.parse_dates_from_taxon_names(True)
    else:
        inst.set_dates_to_be_constant(True)
    return inst


def get_tree_info(inst) -> tp.Tuple[NumpyRootedTree, np.ndarray]:
    """
    Convert the first tree of a bito instance into TreeFlow's NumPy representation.

    Returns
    -------
    tree
        ``NumpyRootedTree`` built from bito's parent-index vector and node heights.
    node_bounds
        bito's node bounds with the first ``taxon_count`` entries dropped.
    """
    bito_tree = inst.tree_collection.trees[0]
    parent_indices = np.array(bito_tree.parent_id_vector())
    node_heights = np.array(bito_tree.node_heights)
    tree = NumpyRootedTree(
        heights=node_heights,
        topology=NumpyTreeTopology(parent_indices=parent_indices),
    )

    # Assumes leaves occupy the first taxon_count slots, so only internal-node
    # bounds are kept. NOTE(review): confirm against bito's node ordering.
    node_bounds = np.array(bito_tree.node_bounds)[tree.taxon_count :]
    return (
        tree,
        node_bounds,
    )


# ---- /treeflow/acceleration/bito/ratio_transform.py ----
import numpy as np
import tensorflow as tf
import bito
from treeflow import DEFAULT_FLOAT_DTYPE_TF


def ratios_to_node_heights_numpy(inst, x):
    """
    Run bito's ratio-to-height transform on the first tree of `inst`.

    `x` is written into the tree via ``initialize_time_tree_using_height_ratios``.
    The last ``x.shape[-1]`` node heights are returned, cast to the dtype of `x`.
    """
    tree = inst.tree_collection.trees[0]
    # copy=False: the heights are read back through this array after bito has
    # updated the tree, so it is relied on to share bito's buffer.
    node_height_state = np.array(tree.node_heights, copy=False)
    tree.initialize_time_tree_using_height_ratios(x)
    return node_height_state[-x.shape[-1] :].astype(x.dtype)


def ratio_gradient_numpy(inst, heights, dheights):
    """
    Convert a gradient with respect to node heights into one with respect to ratios.

    The tree's last ``heights.shape[-1]`` node heights are overwritten in
    place with `heights` before ``bito.ratio_gradient_of_height_gradient``
    is evaluated.
    """
    tree = inst.tree_collection.trees[0]
    node_height_state = np.array(tree.node_heights, copy=False)
    node_height_state[-heights.shape[-1] :] = heights
    return np.array(
        bito.ratio_gradient_of_height_gradient(tree, dheights),
        dtype=heights.dtype,
    )


def ratios_to_node_heights(inst, anchor_heights, ratios):
    """
    TensorFlow op that converts height ratios to node heights using bito.

    The last element of `ratios` (the root entry) is offset by the last anchor
    height before the transform. Gradients come from ``ratio_gradient_numpy``.
    """

    def numpy_func(ratios):
        return ratios_to_node_heights_numpy(inst, ratios)

    def numpy_grad_func(heights, dheights):
        return ratio_gradient_numpy(inst, heights, dheights)

    @tf.custom_gradient
    def bito_tf_func(x):
        heights = tf.numpy_function(
            numpy_func,
            [x],
            DEFAULT_FLOAT_DTYPE_TF,
        )

        def grad(dheights):
            return tf.numpy_function(
                numpy_grad_func,
                [heights, dheights],
                DEFAULT_FLOAT_DTYPE_TF,
            )

        return heights, grad

    with_root_height = tf.concat(
        [ratios[:-1], ratios[-1:] + anchor_heights[-1]], axis=0
    )

    # Libsbn doesn't add root bound
    return bito_tf_func(with_root_height)


# ---- /treeflow/bijectors/__init__.py ----
# https://raw.githubusercontent.com/christiaanjs/treeflow/c10dd306c5c54c9d3ddc3ea8cd69524cb9bc41fb/treeflow/bijectors/__init__.py

# ---- /treeflow/cli/__init__.py ----
# https://raw.githubusercontent.com/christiaanjs/treeflow/c10dd306c5c54c9d3ddc3ea8cd69524cb9bc41fb/treeflow/cli/__init__.py

# ---- /treeflow/cli/inference_common.py ----
import typing as tp
import tensorflow as tf
import tensorflow.keras.optimizers as keras_optimizers
from treeflow.vi.optimizers.robust_optimizer import RobustOptimizer
from treeflow.model.phylo_model import DEFAULT_TREE_VAR_NAME, PhyloModel
from treeflow.tree.io import write_tensor_trees
from treeflow.evolution.seqio import AlignmentFormat

ADAM_KEY = "adam"
ROBUST_ADAM_KEY = "robust_adam"
# Optimizer factories keyed by CLI name; the robust variant wraps Adam in RobustOptimizer
optimizer_builders = {
    ADAM_KEY: keras_optimizers.Adam,
    ROBUST_ADAM_KEY: lambda *args, **kwargs: RobustOptimizer(
        keras_optimizers.Adam(*args, **kwargs)
    ),
}

ALIGNMENT_FORMATS = {format.value: format for format in AlignmentFormat}
DEFAULT_ALIGNMENT_FORMAT = AlignmentFormat.FASTA.value

EXAMPLE_PHYLO_MODEL_DICT = dict(
    tree=dict(coalescent=dict(pop_size=dict(exponential=dict(rate=0.1)))),
    clock=dict(strict=dict(clock_rate=dict(exponential=dict(rate=1000.0)))),
    substitution="jc",
)


def parse_init_value(init_value_string: str) -> tp.Union[float, tp.List[float]]:
    """
    Parse one initial value: a float, or several floats separated by ``|``.

    Returns a single float when there is one component and a list of floats
    otherwise. Raises ValueError if any component is not a valid float.
    """
    split = [float(x) for x in init_value_string.split("|")]
    if len(split) == 1:
        return split[0]
    else:
        return split


class InitialValueParseError(ValueError):
    """Raised when an initial-values string cannot be parsed or names unknown parameters."""

    pass


def parse_init_values(
    init_values_string: str, model_names: tp.Optional[tp.Iterable[str]] = None
) -> tp.Dict[str, tp.Union[float, tp.List[float]]]:
    """
    Parse a ``name=value,name=value`` string of initial parameter values.

    Parameters
    ----------
    init_values_string
        Comma-separated ``name=value`` pairs. Each value is parsed with
        ``parse_init_value``. Whitespace around names is ignored.
    model_names
        If given, every parsed name must be one of these.

    Returns
    -------
    dict
        Mapping from parameter name to a float or a list of floats.

    Raises
    ------
    InitialValueParseError
        If the string is malformed, or contains names not in `model_names`.
    """
    try:
        str_dict = dict(item.split("=") for item in init_values_string.split(","))
        res = {key.strip(): parse_init_value(value) for key, value in str_dict.items()}
    except ValueError as ex:
        # Chain the original error so the failing item stays visible in tracebacks
        raise InitialValueParseError(f"Error parsing initial values: {ex}") from ex

    if model_names is not None:
        extra_keys = set(res.keys()).difference(model_names)
        if len(extra_keys) > 0:
            raise InitialValueParseError(
                f"Unknown parameters in initial values: {extra_keys}"
            )

    return res


def get_tree_vars(model: PhyloModel) -> tp.Set[str]:
    """Names of variables written alongside trees: the tree, plus branch_rates for relaxed clocks."""
    tree_vars = {DEFAULT_TREE_VAR_NAME}
    if model.relaxed_clock():
        tree_vars.add("branch_rates")
    return tree_vars


def write_trees(
    tree_var_samples: tp.Dict[str, tf.Tensor], topology_file, output_file
) -> None:
    """
    Write sampled trees to `output_file`, annotating branches with the other variables.

    The entry under ``DEFAULT_TREE_VAR_NAME`` supplies the branch lengths. Every
    remaining entry is passed as branch metadata. The caller's dict is not modified.
    """
    # Shallow copy so popping the tree entry doesn't mutate the caller's dict
    branch_metadata = dict(tree_var_samples)
    branch_lengths = branch_metadata.pop(DEFAULT_TREE_VAR_NAME).branch_lengths
    write_tensor_trees(
        topology_file, branch_lengths, output_file, branch_metadata=branch_metadata
    )


# ---- /treeflow/debug/__init__.py ----
# https://raw.githubusercontent.com/christiaanjs/treeflow/c10dd306c5c54c9d3ddc3ea8cd69524cb9bc41fb/treeflow/debug/__init__.py

# ---- /treeflow/debug/nonfinite_convergence_criterion.py ----
import tensorflow as tf
from tensorflow_probability.python.optimizer.convergence_criteria import (
    ConvergenceCriterion,
)


def _any_nonfinite(x):
    """True if any element of `x` is NaN or infinite."""
    nonfinite = tf.logical_not(tf.math.is_finite(x))
    return tf.reduce_any(nonfinite)


class NonfiniteConvergenceCriterion(ConvergenceCriterion):
    """
    Convergence criterion that fires when the loss, a gradient or a parameter is non-finite.

    It is intended as a debugging aid: an optimisation loop using it reports
    "converged" at the first step that produces NaN or infinite values.
    """

    def __init__(self, name="NonfiniteConvergenceCriterion"):
        super().__init__(min_num_steps=0, name=name)

    def _bootstrap(self, loss, grads, parameters):
        # No auxiliary state is needed between steps
        return ()

    def _one_step(self, step, loss, grads, parameters, auxiliary_state):
        loss_nonfinite = _any_nonfinite(loss)
        grads_nonfinite = [_any_nonfinite(x) for x in grads]
        parameters_nonfinite = [_any_nonfinite(x) for x in parameters]
        has_converged = tf.reduce_any(
            [loss_nonfinite] + grads_nonfinite + parameters_nonfinite
        )
        return has_converged, ()


__all__ = [NonfiniteConvergenceCriterion.__name__]

# ---- /treeflow/distributions/__init__.py ----
from treeflow.distributions.tree.birthdeath import BirthDeathContemporarySampling, Yule
from treeflow.distributions.discrete import FiniteDiscreteDistribution
from treeflow.distributions.discretized import DiscretizedDistribution
from treeflow.distributions.discrete_parameter_mixture import DiscreteParameterMixture
from treeflow.distributions.leaf_ctmc import LeafCTMC
from treeflow.distributions.sample_weighted import SampleWeighted


# Each name needs its own entry: without a separating comma, adjacent string
# literals are concatenated into a single (nonexistent) name.
__all__ = [
    "BirthDeathContemporarySampling",
    "Yule",
    "FiniteDiscreteDistribution",
    "DiscretizedDistribution",
    "DiscreteParameterMixture",
    "LeafCTMC",
    "SampleWeighted",
]

# ---- /treeflow/distributions/discrete.py ----
from typing_extensions import Protocol
import tensorflow as tf


class FiniteDiscreteDistribution(Protocol):
    """Interface for discrete probability distributions with a finite support"""

    @property
    def support(self) -> tf.Tensor:
        """Values that discrete distribution can take"""
        ...

    @property
    def normalised_support(self) -> tf.Tensor:
        """Support normalised to have a mean of 1 (weighted by probability mass)"""
        ...

    @property
    def probabilities(self):
        """Respective probability masses for support"""
        ...

    @property
    def support_size(self):
        """Number of elements in support"""
        ...
# ---- /treeflow/distributions/markov_chain/__init__.py ----
# https://raw.githubusercontent.com/christiaanjs/treeflow/c10dd306c5c54c9d3ddc3ea8cd69524cb9bc41fb/treeflow/distributions/markov_chain/__init__.py

# ---- /treeflow/distributions/sample_weighted.py ----
import tensorflow as tf
from tensorflow_probability.python.distributions.distribution import Distribution
from tensorflow_probability.python.distributions.sample import Sample
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal.parameter_properties import (
    ParameterProperties,
)


class SampleWeighted(Sample):
    """
    ``Sample`` whose log-probability is a weighted sum over the sample axes.

    Each per-sample log-probability term is multiplied by the matching entry
    of ``weights`` before the reduction over the sample dimensions.
    """

    def __init__(
        self,
        distribution: Distribution,
        weights: tf.Tensor,
        sample_shape=(),
        validate_args=False,
        experimental_use_kahan_sum=False,
        name=None,
    ):
        """
        Parameters
        ----------
        distribution: Distribution
            Distribution that is sampled repeatedly.
        weights: Tensor
            Tensor with shape `sample_shape`; scales the log-probability of
            each sample before summation.
        sample_shape, validate_args, experimental_use_kahan_sum, name
            Forwarded unchanged to ``Sample``.
        """
        # `weights` is stored before locals() is captured, and no other local is
        # created first, so the recorded parameters are exactly the arguments.
        self.weights = weights
        parameters = dict(locals())
        super().__init__(
            distribution=distribution,
            sample_shape=sample_shape,
            validate_args=validate_args,
            experimental_use_kahan_sum=experimental_use_kahan_sum,
            name=name,
        )
        self._parameters = parameters

    def _finish_log_prob(self, lp, aux):
        sample_ndims, extra_sample_ndims, batch_ndims = aux

        # Give every sample axis of lp its full extent before reducing, so that
        # a log-probability that is constant across samples is still counted
        # once per sample.
        full_sample_shape = ps.concat(
            [
                ps.ones([sample_ndims], tf.int32),
                ps.reshape(self.sample_shape, shape=[-1]),
                ps.ones([batch_ndims], tf.int32),
            ],
            axis=0,
        )
        lp_b = tf.broadcast_to(
            lp, ps.broadcast_shape(ps.shape(lp), full_sample_shape)
        )

        # Append singleton batch axes so the weights line up with lp_b's sample axes.
        padded_weight_shape = tf.concat(
            [tf.shape(self.weights), tf.ones(batch_ndims, dtype=tf.int32)], axis=0
        )
        weights_b = tf.reshape(self.weights, padded_weight_shape)

        reduce_axes = ps.range(sample_ndims, sample_ndims + extra_sample_ndims)
        return self._sum_fn()(lp_b * weights_b, axis=reduce_axes)

    @classmethod
    def _parameter_properties(cls, dtype, num_classes=None):
        # Inherit Sample's parameter properties and add one for the weights.
        properties = dict(
            Sample._parameter_properties(dtype, num_classes=num_classes)
        )
        properties["weights"] = ParameterProperties(event_ndims=1)
        return properties


__all__ = ["SampleWeighted"]

# ---- /treeflow/distributions/tree/__init__.py ----
from treeflow.distributions.tree.birthdeath.birth_death_contemporary_sampling import (
    BirthDeathContemporarySampling,
)
from treeflow.distributions.tree.birthdeath.yule import Yule
from treeflow.distributions.tree.coalescent.constant_coalescent import (
    ConstantCoalescent,
)

__all__ = ["Yule", "BirthDeathContemporarySampling", "ConstantCoalescent"]

# ---- /treeflow/distributions/tree/birthdeath/__init__.py ----
from treeflow.distributions.tree.birthdeath.birth_death_contemporary_sampling import (
    BirthDeathContemporarySampling,
)
from treeflow.distributions.tree.birthdeath.yule import Yule
# ---- /treeflow/distributions/tree/birthdeath/yule.py ----
import typing as tp
import tensorflow as tf
from treeflow.distributions.tree.birthdeath.birth_death_contemporary_sampling import (
    BirthDeathContemporarySampling,
)


class Yule(BirthDeathContemporarySampling):
    """
    Yule (pure-birth) tree distribution.

    Implemented as ``BirthDeathContemporarySampling`` with `birth_rate` used as
    the birth-minus-death rate and a relative death rate of zero.
    """

    def __init__(
        self,
        taxon_count,
        birth_rate,
        validate_args=False,
        allow_nan_stats=True,
        name="Yule",
        tree_name: tp.Optional[str] = None,
    ):
        # Captured first so it contains exactly the constructor arguments
        params = dict(locals())
        super().__init__(
            taxon_count=taxon_count,
            birth_diff_rate=birth_rate,
            relative_death_rate=tf.zeros_like(birth_rate),
            validate_args=validate_args,
            allow_nan_stats=allow_nan_stats,
            name=name,
            tree_name=tree_name,
        )
        # Expose Yule's own arguments rather than the parent's
        self._parameters = params

    def _log_coeff(self, dtype):
        # The parent's coefficient term is replaced by zero for the Yule model
        return tf.zeros((), dtype)

    @classmethod
    def _parameter_properties(cls, dtype, num_classes=None):
        # birth_rate is the only parameter; it reuses birth_diff_rate's properties
        parent_properties = BirthDeathContemporarySampling._parameter_properties(
            dtype, num_classes=num_classes
        )
        return {"birth_rate": parent_properties["birth_diff_rate"]}


__all__ = ["Yule"]

# ---- /treeflow/distributions/tree/coalescent/__init__.py ----
from treeflow.distributions.tree.coalescent.constant_coalescent import *

# ---- /treeflow/evolution/__init__.py ----
from treeflow.evolution.seqio import (
    Alignment,
    WeightedAlignment,
    AlignmentType,
    AlignmentFormat,
)
from treeflow.evolution.calibration import *
from treeflow.evolution.substitution import *

__all__ = ["Alignment", "WeightedAlignment", "AlignmentType", "AlignmentFormat"]

# ---- /treeflow/evolution/calibration/__init__.py ----
from treeflow.evolution.calibration.calibration import MRCACalibrationSet

# ---- /treeflow/evolution/calibration/mrca.py ----
import typing as tp
from treeflow.tree.topology.numpy_tree_topology import NumpyTreeTopology


def get_common_ancestors(
    topology: NumpyTreeTopology, indices: tp.Iterable[int]
) -> tp.Set[int]:
    """
    Collect the ancestors shared by every node in `indices`.

    Each node is walked towards the root, which is taken to be index
    ``node_count - 1``. When a walk reaches a node already collected, the set
    is pruned to indices >= that node. NOTE(review): this relies on parents
    having larger indices than their children.
    """
    ancestors: tp.Set[int] = set()
    root_index = topology.node_count - 1
    for start in indices:
        node = start
        while node < root_index:
            parent = topology.parent_indices[node]
            if parent in ancestors:
                # This lineage joins one already seen: keep only the shared path
                ancestors = {a for a in ancestors if a >= parent}
                break
            ancestors.add(parent)
            node = parent
    return ancestors


def get_mrca_index(topology: NumpyTreeTopology, taxa: tp.Iterable[str]) -> int:
    """
    Index of the most recent common ancestor of `taxa`.

    This is the smallest index among their common ancestors. Raises ValueError
    for a taxon not in ``topology.taxon_set`` (from ``list.index``) or when
    there are no common ancestors (from ``min``).
    """
    assert topology.taxon_set is not None
    taxon_list = list(topology.taxon_set)
    leaf_indices = [taxon_list.index(taxon_name) for taxon_name in taxa]
    return min(get_common_ancestors(topology, leaf_indices))


# ---- /treeflow/evolution/substitution/__init__.py ----
from .nucleotide import *
from .probabilities import get_transition_probabilities_tree
# A __future__ import must be the first statement of a module. It belongs to
# eigendecomposition.py below (its method return annotation refers to its own
# class) and is hoisted here so that this reconstructed section remains valid Python.
from __future__ import annotations

# ---- /treeflow/evolution/substitution/base_substitution_model.py ----
import tensorflow as tf
from abc import abstractmethod
from treeflow.evolution.substitution.eigendecomposition import (
    Eigendecomposition,
)


def normalising_constant(q: tf.Tensor, pi: tf.Tensor) -> tf.Tensor:
    """
    Expected substitution rate ``-sum_i q[..., i, i] * pi[..., i]``.

    Parameters
    ----------
    q
        Rate matrix (or batch of rate matrices) with shape ``[..., n, n]``.
    pi
        State frequencies with shape ``[..., n]``.

    Returns
    -------
    tf.Tensor
        Tensor with the broadcast batch shape of `q` and `pi`. Only the state
        axis is reduced, so each batch member gets its own constant; summing
        over every axis would mix batch members together.
    """
    return -tf.reduce_sum(tf.linalg.diag_part(q) * pi, axis=-1)


def normalise(q: tf.Tensor, pi: tf.Tensor) -> tf.Tensor:
    """Scale `q` so that its expected substitution rate under `pi` is 1 (per batch member)."""
    # Two trailing singleton axes align each batch member's constant with its matrix
    return q / normalising_constant(q, pi)[..., tf.newaxis, tf.newaxis]


class SubstitutionModel:
    """Base class for substitution models defined by an instantaneous rate matrix."""

    # NOTE: the class does not use ABCMeta, so @abstractmethod is not enforced
    # at instantiation time.
    @abstractmethod
    def q(self, frequencies: tf.Tensor, **kwargs: tf.Tensor) -> tf.Tensor:
        """Instantaneous rate matrix (not necessarily normalised) for the given parameters."""
        ...

    def q_norm(self, frequencies: tf.Tensor, **kwargs: tf.Tensor) -> tf.Tensor:
        """Rate matrix normalised to an expected substitution rate of 1."""
        return normalise(self.q(frequencies, **kwargs), frequencies)


class EigendecompositionSubstitutionModel(
    SubstitutionModel
):  # TODO: Rename class to time reversible, method to diagonalisation?
    """Substitution model that provides an eigendecomposition of its normalised rate matrix."""

    def eigen(self, frequencies: tf.Tensor, **kwargs: tf.Tensor) -> Eigendecomposition:
        """
        Eigendecomposition of the normalised instantaneous rate matrix.

        Uses ``S = D^{1/2} Q D^{-1/2}`` with ``D = diag(frequencies)``, which is
        symmetric when Q is time-reversible, so ``tf.linalg.eigh`` can be used.
        With U the eigenvectors of S, the eigenvectors of Q are ``D^{-1/2} U``
        and their inverse is ``U^T D^{1/2}``.
        """

        # First create a symmetric matrix from the normalised Q matrix
        q_norm = self.q_norm(frequencies=frequencies, **kwargs)
        sqrt_frequencies = tf.math.sqrt(frequencies)
        inverse_sqrt_frequencies = 1.0 / sqrt_frequencies

        sqrt_frequencies_diag_matrix = tf.linalg.diag(sqrt_frequencies)
        inverse_sqrt_frequencies_diag_matrix = tf.linalg.diag(inverse_sqrt_frequencies)

        symmetric_matrix = (
            sqrt_frequencies_diag_matrix @ q_norm @ inverse_sqrt_frequencies_diag_matrix
        )
        eigenvalues, s_eigenvectors = tf.linalg.eigh(symmetric_matrix)
        eigenvectors = inverse_sqrt_frequencies_diag_matrix @ s_eigenvectors
        inverse_eigenvectors = (
            tf.linalg.matrix_transpose(s_eigenvectors) @ sqrt_frequencies_diag_matrix
        )

        return Eigendecomposition(
            eigenvectors=eigenvectors,
            inverse_eigenvectors=inverse_eigenvectors,
            eigenvalues=eigenvalues,
        )


__all__ = [SubstitutionModel.__name__, EigendecompositionSubstitutionModel.__name__]

# ---- /treeflow/evolution/substitution/eigendecomposition.py ----
# (this module's `from __future__ import annotations` is hoisted to the top of
# this section)
import attr
import tensorflow as tf
import tensorflow.python.util.nest as nest


@attr.s(auto_attribs=True, slots=True)
class Eigendecomposition:
    """
    Eigendecomposition of an instantaneous rate matrix

    Attributes
    ----------
    eigenvectors
        2D Tensor with right eigenvectors as columns
    inverse_eigenvectors
        2D Tensor, inverse of `eigenvectors`
    eigenvalues
        1D Tensor of eigenvalues
    """

    eigenvectors: tf.Tensor
    inverse_eigenvectors: tf.Tensor
    eigenvalues: tf.Tensor

    def add_inner_batch_dimensions(
        self,
        batch_dims: int = 1,
        inner_batch_rank: int = 0,
    ) -> Eigendecomposition:
        """
        Add batch dimensions before the state dimensions

        Inserts `batch_dims` singleton axes, one per recursion level, just before
        the matrix/vector axes and any `inner_batch_rank` inner batch axes.
        """
        # TODO: Reimplement with reshape
        assert batch_dims >= 0
        if batch_dims > 0:
            # The second structure holds, per field, the axis at which to expand
            return nest.map_structure(
                lambda x, dim: tf.expand_dims(x, axis=dim),
                self.add_inner_batch_dimensions(
                    batch_dims - 1, inner_batch_rank=inner_batch_rank
                ),
                Eigendecomposition(
                    eigenvectors=-3 - inner_batch_rank,
                    inverse_eigenvectors=-3 - inner_batch_rank,
                    eigenvalues=-2 - inner_batch_rank,
                ),
            )
        else:
            return self


__all__ = [Eigendecomposition.__name__]

# ---- /treeflow/evolution/substitution/nucleotide/__init__.py ----
from .hky import *
from .jc import *
from .gtr import *

# ---- /treeflow/evolution/substitution/nucleotide/alphabet.py ----
# State indices of the nucleotides in rate matrices and frequency vectors
A, C, G, T = range(4)

# ---- /treeflow/evolution/substitution/nucleotide/gtr.py ----
# (a stray, unused `from struct import pack` import was removed here)
import tensorflow as tf
from treeflow.evolution.substitution.eigendecomposition import Eigendecomposition
from treeflow.evolution.substitution.base_substitution_model import (
    EigendecompositionSubstitutionModel,
)
from treeflow.evolution.substitution.util import pack_matrix

# Order of the six exchangeability rates along the last axis of `rates`
GTR_RATE_ORDER = ("ac", "ag", "at", "cg", "ct", "gt")


class GTR(EigendecompositionSubstitutionModel):
    """General time-reversible nucleotide model; rates are ordered as in GTR_RATE_ORDER."""

    def q(self, frequencies: tf.Tensor, rates: tf.Tensor) -> tf.Tensor:
        """Rate matrix ``q[i, j] = rate(i, j) * pi[j]``; each diagonal makes its row sum to zero."""
        pi = frequencies
        return pack_matrix(
            [
                [
                    -(
                        rates[..., 0] * pi[..., 1]
                        + rates[..., 1] * pi[..., 2]
                        + rates[..., 2] * pi[..., 3]
                    ),
                    rates[..., 0] * pi[..., 1],
                    rates[..., 1] * pi[..., 2],
                    rates[..., 2] * pi[..., 3],
                ],
                [
                    rates[..., 0] * pi[..., 0],
                    -(
                        rates[..., 0] * pi[..., 0]
                        + rates[..., 3] * pi[..., 2]
                        + rates[..., 4] * pi[..., 3]
                    ),
                    rates[..., 3] * pi[..., 2],
                    rates[..., 4] * pi[..., 3],
                ],
                [
                    rates[..., 1] * pi[..., 0],
                    rates[..., 3] * pi[..., 1],
                    -(
                        rates[..., 1] * pi[..., 0]
                        + rates[..., 3] * pi[..., 1]
                        + rates[..., 5] * pi[..., 3]
                    ),
                    rates[..., 5] * pi[..., 3],
                ],
                [
                    rates[..., 2] * pi[..., 0],
                    rates[..., 4] * pi[..., 1],
                    rates[..., 5] * pi[..., 2],
                    -(
                        rates[..., 2] * pi[..., 0]
                        + rates[..., 4] * pi[..., 1]
                        + rates[..., 5] * pi[..., 2]
                    ),
                ],
            ],
        )


__all__ = ["GTR"]

# ---- /treeflow/evolution/substitution/nucleotide/hky.py ----
import tensorflow as tf
from treeflow.evolution.substitution.eigendecomposition import Eigendecomposition
from treeflow.evolution.substitution.base_substitution_model import (
    EigendecompositionSubstitutionModel,
)
from treeflow.evolution.substitution.nucleotide.alphabet import A, C, G, T
from treeflow.evolution.substitution.util import pack_matrix, pack_matrix_transposed


class HKY(EigendecompositionSubstitutionModel):
    """HKY nucleotide model: transitions (A<->G, C<->T) are scaled by `kappa`."""

    def q(self, frequencies: tf.Tensor, kappa: tf.Tensor) -> tf.Tensor:
        """Unnormalised HKY rate matrix."""
        pi = frequencies

        piA = pi[..., A]
        piC = pi[..., C]
        piT = pi[..., T]
        piG = pi[..., G]
        return pack_matrix(
            [
                [-(piC + kappa * piG + piT), piC, kappa * piG, piT],
                [piA, -(piA + piG + kappa * piT), piG, kappa * piT],
                [kappa * piA, piC, -(kappa * piA + piC + piT), piT],
                [piA, kappa * piC, piG, -(piA + kappa * piC + piG)],
            ],
        )

    def eigen(self, frequencies: tf.Tensor, kappa: tf.Tensor) -> Eigendecomposition:
        """Closed-form eigendecomposition of the normalised HKY rate matrix."""
        pi = frequencies
        piA = pi[..., A]
        piC = pi[..., C]
        piT = pi[..., T]
        piG = pi[..., G]
        piY = piT + piC  # pyrimidines
        piR = piA + piG  # purines

        # Scale factor that normalises the expected substitution rate to 1
        beta = -1.0 / (2.0 * (piR * piY + kappa * (piA * piG + piC * piT)))

        one = tf.ones_like(kappa)
        zero = tf.zeros_like(kappa)
        minus_one = -one

        eigenvalues = tf.stack(
            [
                zero,
                beta,
                beta * (piY * kappa + piR),
                beta * (piY + piR * kappa),
            ],
            axis=-1,
        )
        # Each inner list is one eigenvector (a column of the resulting matrix)
        eigenvectors = pack_matrix_transposed(
            [
                [one, one, one, one],
                [1.0 / piR, -1.0 / piY, 1.0 / piR, -1.0 / piY],
                [zero, piT / piY, zero, -piC / piY],
                [piG / piR, zero, -piA / piR, zero],
            ],
        )
        inverse_eigenvectors = pack_matrix(
            [
                [piA, piC, piG, piT],
                [piA * piY, -piC * piR, piG * piY, -piT * piR],
                [zero, one, zero, minus_one],
                [one, zero, minus_one, zero],
            ],
        )

        return Eigendecomposition(
            eigenvalues=eigenvalues,
            eigenvectors=eigenvectors,
            inverse_eigenvectors=inverse_eigenvectors,
        )


__all__ = [HKY.__name__]

# ---- /treeflow/evolution/substitution/nucleotide/jc.py ----
import typing as tp
from treeflow import DEFAULT_FLOAT_DTYPE_TF
import tensorflow as tf
import numpy as np
from treeflow.evolution.substitution.eigendecomposition import Eigendecomposition
from treeflow.evolution.substitution.base_substitution_model import (
    EigendecompositionSubstitutionModel,
)
class JC(EigendecompositionSubstitutionModel):
    """
    Jukes-Cantor nucleotide substitution model.

    All exchange rates are equal and the stationary distribution is uniform,
    so the normalised rate matrix and its eigendecomposition are constants.
    """

    @staticmethod
    def frequencies(dtype=DEFAULT_FLOAT_DTYPE_TF):
        """Uniform equilibrium frequencies over the four nucleotides."""
        return tf.constant([0.25, 0.25, 0.25, 0.25], dtype=dtype)

    def q(self, frequencies: tf.Tensor, dtype=DEFAULT_FLOAT_DTYPE_TF) -> tf.Tensor:
        """Constant rate matrix: -1 on the diagonal, 1/3 elsewhere (`frequencies` is unused)."""
        off_diagonal = 1 / 3
        rows = [[-1 if i == j else off_diagonal for j in range(4)] for i in range(4)]
        return tf.constant(rows, dtype=dtype)

    def eigen(
        self,
        frequencies: tp.Optional[tf.Tensor] = None,
        dtype: tf.DType = DEFAULT_FLOAT_DTYPE_TF,
    ) -> Eigendecomposition:
        """Precomputed eigendecomposition of the Jukes-Cantor rate matrix."""
        # Non-zero eigenvalue of the normalised JC matrix, -4/3
        decay = -1.3333333333333333
        right_vectors = [
            [1.0, 2.0, 0.0, 0.5],
            [1.0, -2.0, 0.5, 0.0],
            [1.0, 2.0, 0.0, -0.5],
            [1.0, -2.0, -0.5, 0.0],
        ]
        left_vectors = [
            [0.25, 0.25, 0.25, 0.25],
            [0.125, -0.125, 0.125, -0.125],
            [0.0, 1.0, 0.0, -1.0],
            [1.0, 0.0, -1.0, 0.0],
        ]
        return Eigendecomposition(
            eigenvectors=tf.constant(right_vectors, dtype=dtype),
            eigenvalues=tf.constant([0.0, decay, decay, decay], dtype=dtype),
            inverse_eigenvectors=tf.constant(left_vectors, dtype=dtype),
        )


__all__ = ["JC"]

# ---- /treeflow/evolution/substitution/util.py ----
import tensorflow as tf


def pack_matrix(mat):
    """Stack a nested list of equally shaped tensors into matrices; inner lists are rows."""
    packed_rows = [tf.stack(row, axis=-1) for row in mat]
    return tf.stack(packed_rows, axis=-2)


def pack_matrix_transposed(mat):
    """Stack a nested list of equally shaped tensors into matrices; inner lists are columns."""
    packed_columns = [tf.stack(col, axis=-1) for col in mat]
    return tf.stack(packed_columns, axis=-1)


# ---- /treeflow/model/__init__.py ----
from treeflow.model.phylo_model import PhyloModel, phylo_model_to_joint_distribution

__all__ = [
    "PhyloModel",
    "phylo_model_to_joint_distribution",
]

# ---- /treeflow/model/approximation/__init__.py ----
from .mean_field import (
    get_mean_field_approximation,
    get_fixed_topology_mean_field_approximation,
)
from .iaf import (
    get_inverse_autoregressive_flow_approximation,
    get_fixed_topology_inverse_autoregressive_flow_approximation,
)

# ---- /treeflow/model/structured_approximation.py ----


def normal_natural_params_to_loc_and_scale():
    """Placeholder: not implemented yet; returns None."""


# ---- /treeflow/tf_util/__init__.py ----
from treeflow.tf_util.dtype_util import (
    DEFAULT_FLOAT_DTYPE_TF,
    DEFAULT_FLOAT_DTYPE_NP,
    float_constant,
)
from treeflow.tf_util.attrs import AttrsLengthMixin

__all__ = [
    "DEFAULT_FLOAT_DTYPE_TF",
    "DEFAULT_FLOAT_DTYPE_NP",
    "float_constant",
    "AttrsLengthMixin",
]

# ---- /treeflow/tf_util/attrs.py ----
import tensorflow.python.util.nest as nest
import warnings


class AttrsLengthMixin:
    """Mixin giving attrs classes a ``len()`` equal to their number of attrs fields."""

    def __len__(self) -> int:
        warnings.warn("Temporary hotfix")
        # NOTE(review): relies on private TensorFlow nest helpers
        assert nest._is_attrs(self)
        attrs_items = nest._get_attrs_items(self)
        return len(attrs_items)


__all__ = ["AttrsLengthMixin"]

# ---- /treeflow/tf_util/dtype_util.py ----
import typing as tp 2 | import tensorflow as tf 3 | import numpy as np 4 | 5 | DEFAULT_FLOAT_DTYPE_TF = tf.float64 6 | DEFAULT_FLOAT_DTYPE_NP = np.float64 7 | 8 | 9 | def float_constant(x: tp.Union[float, np.ndarray, tp.Iterable[float]]): 10 | """ 11 | Converts a floating point value or array to a constant Tensor with TreeFlow's 12 | default data type 13 | 14 | Parameters 15 | ---------- 16 | x 17 | Value that can be converted to a Tensor 18 | 19 | Returns 20 | ------- 21 | tf.Tensor 22 | Value converted to a constant tensor 23 | """ 24 | return tf.constant(x, dtype=DEFAULT_FLOAT_DTYPE_TF) 25 | 26 | 27 | __all__ = ["DEFAULT_FLOAT_DTYPE_TF", "DEFAULT_FLOAT_DTYPE_NP", "float_constant"] 28 | -------------------------------------------------------------------------------- /treeflow/tf_util/vectorize.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow_probability.python.internal import prefer_static as ps 3 | from tensorflow_probability.python.internal import distribution_util 4 | 5 | 6 | def broadcast_structure(elems, event_shape, batch_shape): 7 | return tf.nest.map_structure( 8 | lambda elem, elem_event_shape: tf.broadcast_to( 9 | elem, tf.concat([batch_shape, elem_event_shape], 0) 10 | ), 11 | elems, 12 | event_shape, 13 | ) 14 | 15 | 16 | def reshape_structure(elems, event_shape, new_batch_shape): 17 | return tf.nest.map_structure( 18 | lambda elem, elem_event_shape: tf.reshape( 19 | elem, tf.concat([new_batch_shape, elem_event_shape], 0) 20 | ), 21 | elems, 22 | event_shape, 23 | ) 24 | 25 | 26 | def vectorize_over_batch_dims( 27 | fn, elems, event_shape, batch_shape, vectorized_map=True, fn_output_signature=None 28 | ): 29 | flat_batch_shape = tf.expand_dims(ps.reduce_prod(batch_shape), 0) 30 | flat_structure = reshape_structure(elems, event_shape, flat_batch_shape) 31 | if vectorized_map: 32 | result = tf.vectorized_map(fn, flat_structure, fallback_to_while_loop=False) 33 | else: 34 
| assert fn_output_signature is not None 35 | result = tf.map_fn(fn, flat_structure, fn_output_signature=fn_output_signature) 36 | new_event_shape = tf.nest.map_structure(lambda elem: tf.shape(elem)[1:], result) 37 | return reshape_structure(result, new_event_shape, batch_shape) 38 | -------------------------------------------------------------------------------- /treeflow/traversal/__init__.py: -------------------------------------------------------------------------------- 1 | """Tree-traversal based computations""" 2 | -------------------------------------------------------------------------------- /treeflow/traversal/anchor_heights.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from treeflow.tree.rooted.numpy_rooted_tree import NumpyRootedTree 3 | from treeflow.tree.topology.tensorflow_tree_topology import TensorflowTreeTopology 4 | import tensorflow as tf 5 | 6 | 7 | def get_anchor_heights(tree: NumpyRootedTree) -> np.ndarray: 8 | taxon_count = tree.taxon_count 9 | anchor_heights = np.zeros_like(tree.heights) 10 | anchor_heights[..., :taxon_count] = tree.heights[..., :taxon_count] 11 | 12 | for i in tree.topology.postorder_node_indices: 13 | anchor_heights[..., i] = np.max( 14 | anchor_heights[..., tree.topology.child_indices[i]], axis=-1 15 | ) 16 | 17 | return anchor_heights[..., taxon_count:] 18 | 19 | 20 | @tf.function 21 | def get_anchor_heights_tensor( 22 | topology: TensorflowTreeTopology, sampling_times: tf.Tensor 23 | ): 24 | taxon_count = topology.taxon_count 25 | anchor_heights = tf.TensorArray(sampling_times.dtype, size=taxon_count * 2 - 1) 26 | for i in tf.range(taxon_count): 27 | anchor_heights = anchor_heights.write(i, sampling_times[..., i]) 28 | child_indices = topology.child_indices 29 | for i in topology.postorder_node_indices: 30 | child_anchor_heights = anchor_heights.gather(child_indices[i]) 31 | anchor_heights = anchor_heights.write( 32 | i, 
tf.reduce_max(child_anchor_heights, axis=0) 33 | ) 34 | 35 | rank = tf.shape(tf.shape(sampling_times))[0] 36 | perm = tf.concat([tf.range(1, rank), [0]], axis=0) 37 | return tf.transpose(anchor_heights.gather(topology.postorder_node_indices), perm) 38 | 39 | 40 | __all__ = [get_anchor_heights.__name__, get_anchor_heights_tensor.__name__] 41 | -------------------------------------------------------------------------------- /treeflow/traversal/postorder.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | import attr 3 | import tensorflow as tf 4 | from treeflow.tree.topology.tensorflow_tree_topology import TensorflowTreeTopology 5 | 6 | 7 | TInputStructure = tp.TypeVar("TInputStructure") 8 | TOutputStructure = tp.TypeVar("TOutputStructure") 9 | 10 | 11 | @attr.attrs(auto_attribs=True) 12 | class PostorderTopologyData: 13 | child_indices: tf.Tensor 14 | 15 | 16 | def postorder_node_traversal( 17 | topology: TensorflowTreeTopology, 18 | mapping: tp.Callable[ 19 | [TOutputStructure, TInputStructure, PostorderTopologyData], TOutputStructure 20 | ], 21 | input: TInputStructure, 22 | leaf_init: TOutputStructure, 23 | ) -> TOutputStructure: 24 | taxon_count = topology.taxon_count 25 | node_count = 2 * taxon_count - 1 26 | tensorarrays = tf.nest.map_structure( 27 | lambda x: tf.TensorArray( 28 | dtype=x.dtype, 29 | size=node_count, 30 | element_shape=x.shape[1:], 31 | clear_after_read=False, 32 | ), 33 | leaf_init, 34 | ) 35 | for i in tf.range(taxon_count): 36 | tensorarrays = tf.nest.map_structure( 37 | lambda x, ta: ta.write(i, x[i]), leaf_init, tensorarrays 38 | ) 39 | postorder_node_indices = topology.postorder_node_indices 40 | child_indices = topology.child_indices 41 | for i in tf.range(taxon_count - 1): 42 | node_index = postorder_node_indices[i] 43 | node_child_indices = child_indices[node_index] 44 | child_output = tf.nest.map_structure( 45 | lambda ta: ta.gather(node_child_indices), tensorarrays 46 | ) 
47 | node_input = tf.nest.map_structure(lambda x: x[node_index - taxon_count], input) 48 | topology_data = PostorderTopologyData(child_indices=node_child_indices) 49 | output = mapping(child_output, node_input, topology_data) 50 | tensorarrays = tf.nest.map_structure( 51 | lambda x, ta: ta.write(node_index, x), output, tensorarrays 52 | ) 53 | return tf.nest.map_structure(lambda x: x.stack(), tensorarrays) 54 | -------------------------------------------------------------------------------- /treeflow/traversal/preorder.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | import tensorflow as tf 3 | from treeflow.tree.topology.tensorflow_tree_topology import TensorflowTreeTopology 4 | 5 | TInputStructure = tp.TypeVar("TInputStructure") 6 | TOutputStructure = tp.TypeVar("TOutputStructure") 7 | 8 | 9 | @tf.function 10 | def preorder_traversal( 11 | topology: TensorflowTreeTopology, 12 | mapping: tp.Callable[[TOutputStructure, TInputStructure], TOutputStructure], 13 | input: TInputStructure, 14 | root_init: TOutputStructure, 15 | ) -> TOutputStructure: 16 | taxon_count = topology.taxon_count 17 | node_count = taxon_count - 1 18 | tensorarrays = tf.nest.map_structure( 19 | lambda x: tf.TensorArray( 20 | dtype=x.dtype, 21 | size=node_count, 22 | element_shape=x.shape, 23 | clear_after_read=False, 24 | ), 25 | root_init, 26 | ) 27 | tensorarrays = tf.nest.map_structure( 28 | lambda x, ta: ta.write(node_count - 1, x), root_init, tensorarrays 29 | ) 30 | 31 | parent_indices = topology.parent_indices[topology.taxon_count :] - taxon_count 32 | for i in topology.preorder_node_indices[1:] - taxon_count: 33 | parent_index = parent_indices[i] 34 | parent_output = tf.nest.map_structure( 35 | lambda ta: ta.read(parent_index), tensorarrays 36 | ) 37 | node_input = tf.nest.map_structure(lambda x: x[i], input) 38 | output = mapping(parent_output, node_input) 39 | tensorarrays = tf.nest.map_structure( 40 | lambda x, ta: 
ta.write(i, x), output, tensorarrays 41 | ) 42 | 43 | return tf.nest.map_structure(lambda x: x.stack(), tensorarrays) 44 | -------------------------------------------------------------------------------- /treeflow/traversal/ratio_transform.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow_probability.python.internal import prefer_static as ps 3 | 4 | 5 | def move_outside_axis_to_inside(x): 6 | rank = tf.rank(x) 7 | perm = tf.concat([tf.range(1, rank), [0]], axis=0) 8 | return tf.transpose(x, perm) 9 | 10 | 11 | @tf.function 12 | def ratios_to_node_heights( 13 | preorder_node_indices: tf.Tensor, 14 | parent_indices: tf.Tensor, 15 | ratios: tf.Tensor, 16 | anchor_heights: tf.Tensor, 17 | ): 18 | node_count = ratios.shape[-1] 19 | node_heights_ta = tf.TensorArray( 20 | dtype=ratios.dtype, 21 | size=node_count, 22 | element_shape=ratios.shape[:-1], 23 | clear_after_read=False, 24 | ) 25 | 26 | node_heights_ta = node_heights_ta.write( 27 | node_count - 1, 28 | ratios[..., node_count - 1] + anchor_heights[..., node_count - 1], 29 | ) 30 | for i in preorder_node_indices[1:]: 31 | parent_height = node_heights_ta.read(parent_indices[i]) 32 | anchor_height = anchor_heights[..., i] 33 | proportion = ratios[..., i] 34 | node_height = (parent_height - anchor_height) * proportion + anchor_height 35 | node_heights_ta = node_heights_ta.write(i, node_height) 36 | 37 | return move_outside_axis_to_inside(node_heights_ta.stack()) 38 | -------------------------------------------------------------------------------- /treeflow/tree/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Classes for representing trees in Tensorflow 3 | 4 | We use `attr` based classes as they are supported by `tf.nest`. 5 | Custom behaviour can be added subclassing `attr` classes as long as 6 | support for the standard constructor argument order is preserved. 
7 | """ 8 | 9 | from treeflow.tree.taxon_set import DictTaxonSet, TupleTaxonSet 10 | from treeflow.tree.topology.numpy_tree_topology import NumpyTreeTopology 11 | from treeflow.tree.rooted.numpy_rooted_tree import NumpyRootedTree 12 | from treeflow.tree.rooted.tensorflow_rooted_tree import ( 13 | TensorflowRootedTree, 14 | convert_tree_to_tensor, 15 | tree_from_arrays, 16 | ) 17 | from treeflow.tree.io import parse_newick, write_tensor_trees 18 | 19 | __all__ = ["parse_newick", "convert_tree_to_tensor", "write_tensor_trees"] 20 | -------------------------------------------------------------------------------- /treeflow/tree/base_tree.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from abc import abstractproperty, abstractmethod 3 | from treeflow.tree.taxon_set import TaxonSet 4 | import typing as tp 5 | from treeflow.tree.topology.base_tree_topology import ( 6 | AbstractTreeTopology, 7 | ) 8 | 9 | 10 | TDataType = tp.TypeVar("TDataType") 11 | TShapeType = tp.TypeVar("TShapeType") 12 | 13 | 14 | class AbstractTree(tp.Generic[TDataType, TShapeType]): 15 | @abstractproperty 16 | def topology(self) -> AbstractTreeTopology[TDataType, TShapeType]: 17 | pass 18 | 19 | @abstractproperty 20 | def branch_lengths(self) -> TDataType: 21 | pass 22 | 23 | @property 24 | def taxon_set(self) -> tp.Optional[TaxonSet]: 25 | return self.topology.taxon_set 26 | 27 | @property 28 | def taxon_count(self) -> TShapeType: 29 | return self.topology.taxon_count 30 | 31 | @abstractmethod 32 | def get_unrooted_tree(self) -> AbstractTree: 33 | pass 34 | -------------------------------------------------------------------------------- /treeflow/tree/rooted/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christiaanjs/treeflow/c10dd306c5c54c9d3ddc3ea8cd69524cb9bc41fb/treeflow/tree/rooted/__init__.py 
-------------------------------------------------------------------------------- /treeflow/tree/rooted/base_rooted_tree.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from abc import abstractproperty 4 | from treeflow.tree.base_tree import AbstractTree 5 | from treeflow.tree.topology.base_tree_topology import ( 6 | AbstractTreeTopology, 7 | BaseTreeTopology, 8 | ) 9 | from treeflow.tree.unrooted.base_unrooted_tree import BaseUnrootedTree 10 | import attr 11 | import typing as tp 12 | from treeflow.tree.taxon_set import TaxonSet 13 | from treeflow.tf_util import AttrsLengthMixin 14 | 15 | TDataType = tp.TypeVar("TDataType") 16 | TShapeType = tp.TypeVar("TShapeType") 17 | TUnrootedTreeType = tp.TypeVar("TUnrootedTreeType") 18 | 19 | 20 | @attr.s(auto_attribs=True, slots=True) 21 | class BaseRootedTree(tp.Generic[TDataType], AttrsLengthMixin): 22 | topology: BaseTreeTopology[TDataType] 23 | node_heights: TDataType 24 | sampling_times: TDataType 25 | 26 | 27 | @attr.s(auto_attribs=True, slots=True, init=False) 28 | class AbstractRootedTreeAttrs( 29 | BaseRootedTree[TDataType], 30 | AbstractTree[TDataType, TShapeType], 31 | tp.Generic[TDataType, TShapeType], 32 | ): 33 | topology: AbstractTreeTopology[TDataType, TShapeType] 34 | 35 | def __init__( 36 | self, 37 | tree_or_first_arg: tp.Optional[ 38 | tp.Union[AbstractRootedTreeAttrs, object] 39 | ] = None, 40 | *args, 41 | **kwargs, 42 | ): # This logic is because `tf.nest.cast_structure` expects a copy constructor 43 | if isinstance(tree_or_first_arg, BaseRootedTree): 44 | self.__attrs_init__( 45 | topology=tree_or_first_arg.topology, 46 | node_heights=tree_or_first_arg.node_heights, 47 | sampling_times=tree_or_first_arg.sampling_times, 48 | ) 49 | elif attr.fields(type(self))[0].name in kwargs: 50 | self.__attrs_init__(*args, **kwargs) 51 | else: 52 | self.__attrs_init__(tree_or_first_arg, *args, **kwargs) 53 | 54 | 55 | class 
AbstractRootedTree( 56 | AbstractRootedTreeAttrs[TDataType, TShapeType], 57 | tp.Generic[TDataType, TShapeType, TUnrootedTreeType], 58 | ): 59 | 60 | UnrootedTreeType: tp.Type = BaseUnrootedTree # TODO: Better type hint 61 | 62 | @abstractproperty 63 | def heights(self) -> TDataType: 64 | pass 65 | 66 | def get_unrooted_tree(self) -> TUnrootedTreeType: 67 | return type(self).UnrootedTreeType( 68 | topology=self.topology, branch_lengths=self.branch_lengths 69 | ) 70 | -------------------------------------------------------------------------------- /treeflow/tree/taxon_set.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import typing as tp 3 | from typing_extensions import Protocol 4 | 5 | 6 | class TaxonSet(Protocol): 7 | """ 8 | Interface representing a taxon set. 9 | 10 | A TaxonSet is an ordered set of taxon names, usually associated with 11 | the ordered leaves of a phylogenetic tree. 12 | """ 13 | 14 | def __init__(self, taxa: tp.Iterable[str]): 15 | ... 16 | 17 | def __eq__(self, o: object) -> bool: 18 | ... 19 | 20 | def __len__(self) -> int: 21 | ... 22 | 23 | def __iter__(self) -> tp.Iterator[str]: 24 | ... 25 | 26 | def __contains__(self, value: str): 27 | ... 28 | 29 | 30 | class DictTaxonSet(tp.Dict[str, None]): 31 | """ 32 | Taxon set implementation based on the built-in dictionary type. 33 | 34 | An TaxonSet is an ordered set of taxon names, usually associated with 35 | the ordered leaves of a phylogenetic tree. 36 | The dictionary is used as an ordered set - the keys are the taxon names. 
37 | """ 38 | 39 | def __init__(self, taxa: tp.Iterable[str]): 40 | super().__init__([(taxon, None) for taxon in taxa]) 41 | 42 | def __eq__(self, o: object) -> bool: 43 | if isinstance(o, tp.Iterable): 44 | return tuple(self) == tuple(o) 45 | else: 46 | return False 47 | 48 | def __ne__(self, o: object) -> bool: 49 | if isinstance(o, tp.Iterable): 50 | return tuple(self) != tuple(o) 51 | else: 52 | return True 53 | 54 | def __repr__(self) -> str: 55 | return repr(tuple(self.keys())) 56 | 57 | def __str__(self) -> str: 58 | return str(tuple(self.keys())) 59 | 60 | 61 | class TupleTaxonSet(tp.Tuple[str, ...]): 62 | """ 63 | Taxon set implementation based on the built-in tuple type. 64 | 65 | An TaxonSet is an ordered set of taxon names, usually associated with 66 | the ordered leaves of a phylogenetic tree. 67 | This implementation is required for some TensorFlow functionality. 68 | """ 69 | 70 | pass 71 | 72 | 73 | __all__ = ["TaxonSet", "DictTaxonSet", "TupleTaxonSet"] 74 | -------------------------------------------------------------------------------- /treeflow/tree/topology/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christiaanjs/treeflow/c10dd306c5c54c9d3ddc3ea8cd69524cb9bc41fb/treeflow/tree/topology/__init__.py -------------------------------------------------------------------------------- /treeflow/tree/topology/base_tree_topology.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | from abc import abstractproperty 4 | from treeflow.tree.taxon_set import TaxonSet 5 | 6 | TDataType = tp.TypeVar("TDataType") 7 | TShapeType = tp.TypeVar("TShapeType") 8 | import attr 9 | 10 | 11 | @attr.s(auto_attribs=True, slots=True) 12 | class BaseTreeTopology(tp.Generic[TDataType]): 13 | parent_indices: TDataType 14 | 15 | 16 | class AbstractTreeTopology( 17 | BaseTreeTopology[TDataType], tp.Generic[TDataType, TShapeType] 18 | 
): 19 | """ 20 | Class representing a bifurcating tree topology as a composition of integer arrays. 21 | 22 | For a phylogenetic tree with ``n`` taxa at the leaves, the representation 23 | maintains a labelling of the ``1n-1`` nodes with integer indices. The labelling 24 | convention is that the leaves are the first ``n`` indices and the root is at the last 25 | index (``2n-2``). 26 | """ 27 | 28 | @abstractproperty 29 | def child_indices(self) -> TDataType: 30 | """ 31 | Array of length ``2n-2`` representing the parent-child structure of a 32 | tree topology on ``n`` taxa. 33 | 34 | The ``i`` th element of this array is the index of the parent of the ``i`` th 35 | indexed node in the tree, for every node except the root. 36 | """ 37 | pass 38 | 39 | @abstractproperty 40 | def preorder_indices(self) -> TDataType: 41 | """ 42 | Array of indices of length ``2n-1`` that form a pre-order traversal 43 | of the tree. 44 | 45 | A pre-order traversal is an ordering of the nodes where every 46 | node is visited before its children, starting at the root. 
47 | """ 48 | pass 49 | 50 | @abstractproperty 51 | def taxon_count(self) -> TShapeType: 52 | """Number of leaf taxa the tree is based on (``n``)""" 53 | pass 54 | 55 | @property 56 | def postorder_node_indices(self) -> TDataType: 57 | pass 58 | 59 | @abstractproperty 60 | def taxon_set(self) -> tp.Optional[TaxonSet]: 61 | pass 62 | 63 | 64 | __all__ = [BaseTreeTopology.__name__, AbstractTreeTopology.__name__] 65 | -------------------------------------------------------------------------------- /treeflow/tree/topology/numpy_topology_operations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import typing as tp 3 | 4 | 5 | def _get_child_indices_flat(parent_indices: np.ndarray) -> np.ndarray: 6 | """ 7 | Get parent indices 8 | 9 | Params 10 | ------ 11 | parent_indices 12 | """ 13 | node_count = parent_indices.shape[-1] + 1 14 | child_indices = np.full((node_count, 2), -1) 15 | current_child = np.zeros(node_count, dtype=int) 16 | for i in range(node_count - 1): 17 | parent = parent_indices[i] 18 | if ( 19 | current_child[parent] == 0 20 | or child_indices[parent, current_child[parent] - 1] < i 21 | ): 22 | child_indices[parent, current_child[parent]] = i 23 | else: # Ensure last axis sorted 24 | child_indices[parent, current_child[parent]] = child_indices[ 25 | parent, current_child[parent] - 1 26 | ] 27 | child_indices[parent, current_child[parent] - 1] = 1 28 | current_child[parent] += 1 29 | return child_indices 30 | 31 | 32 | get_child_indices: tp.Callable[[np.ndarray], np.ndarray] = np.vectorize( 33 | _get_child_indices_flat, otypes=[np.int32], signature="(m)->(n,2)" 34 | ) 35 | 36 | 37 | def _get_preorder_indices_flat(child_indices: np.ndarray): 38 | """ 39 | Get preorder indices 40 | 41 | Params 42 | ------ 43 | child_indices 44 | """ 45 | node_count = child_indices.shape[-2] 46 | 47 | def is_leaf(node_index): 48 | return child_indices[node_index, 0] == -1 49 | 50 | stack = np.zeros(node_count, 
dtype=int) 51 | stack[0] = len(child_indices) - 1 52 | stack_length = 1 53 | 54 | visited = np.zeros(node_count, dtype=int) 55 | visited_count = 0 56 | 57 | while stack_length: 58 | node_index = stack[stack_length - 1] 59 | stack_length -= 1 60 | if not is_leaf(node_index): 61 | for child_index in child_indices[node_index][::-1]: 62 | stack[stack_length] = child_index 63 | stack_length += 1 64 | visited[visited_count] = node_index 65 | visited_count += 1 66 | return visited 67 | 68 | 69 | get_preorder_indices: tp.Callable[[np.ndarray], np.ndarray] = np.vectorize( 70 | _get_preorder_indices_flat, otypes=[np.int32], signature="(m,2)->(m)" 71 | ) 72 | 73 | __all__ = ["get_child_indices", "get_preorder_indices"] 74 | -------------------------------------------------------------------------------- /treeflow/tree/topology/numpy_tree_topology.py: -------------------------------------------------------------------------------- 1 | from treeflow.tree import taxon_set 2 | import typing as tp 3 | import attr 4 | from treeflow.tree.topology.base_tree_topology import AbstractTreeTopology 5 | import numpy as np 6 | from treeflow.tree.taxon_set import TaxonSet 7 | from treeflow.tree.topology.numpy_topology_operations import ( 8 | get_child_indices, 9 | get_preorder_indices, 10 | ) 11 | 12 | 13 | @attr.attrs(auto_attribs=True) 14 | class NumpyTreeTopologyAttrs(AbstractTreeTopology[np.ndarray, int]): 15 | parent_indices: np.ndarray # Convenience type hint 16 | 17 | 18 | class NumpyTreeTopology(NumpyTreeTopologyAttrs): 19 | """ 20 | Class representing a bifurcating tree topology as a composition of integer 21 | NumPy arrays. 22 | 23 | For a phylogenetic tree with ``n`` taxa at the leaves, the representation 24 | maintains a labelling of the ``1n-1`` nodes with integer indices. The labelling 25 | convention is that the leaves are the first ``n`` indices and the root is at the 26 | last index (``2n-2``). 
27 | """ 28 | 29 | def __init__( 30 | self, parent_indices: np.ndarray, taxon_set: tp.Optional[TaxonSet] = None 31 | ): 32 | super().__init__(parent_indices=parent_indices) 33 | self._taxon_set = taxon_set 34 | 35 | @property 36 | def taxon_count(self) -> int: 37 | return (self.parent_indices.shape[-1] + 2) // 2 38 | 39 | @property 40 | def node_count(self) -> int: 41 | return 2 * self.taxon_count - 1 42 | 43 | @property 44 | def postorder_node_indices(self) -> np.ndarray: 45 | return np.arange(self.taxon_count, 2 * self.taxon_count - 1) 46 | 47 | @property 48 | def child_indices(self) -> np.ndarray: 49 | return get_child_indices(self.parent_indices) 50 | 51 | @property 52 | def preorder_indices(self) -> np.ndarray: 53 | return get_preorder_indices(self.child_indices) 54 | 55 | @property 56 | def taxon_set(self) -> tp.Optional[TaxonSet]: 57 | return self._taxon_set 58 | 59 | # Methods to allow pickling 60 | def __getstate__(self): 61 | return (super().__getstate__(), self._taxon_set) 62 | 63 | def __setstate__(self, state): 64 | super().__setstate__(state[0]) 65 | self._taxon_set = state[1] 66 | 67 | 68 | __all__ = ["NumpyTreeTopology"] 69 | -------------------------------------------------------------------------------- /treeflow/tree/unrooted/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/christiaanjs/treeflow/c10dd306c5c54c9d3ddc3ea8cd69524cb9bc41fb/treeflow/tree/unrooted/__init__.py -------------------------------------------------------------------------------- /treeflow/tree/unrooted/base_unrooted_tree.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import typing as tp 3 | from treeflow.tree.base_tree import AbstractTree 4 | from treeflow.tree.topology.base_tree_topology import ( 5 | AbstractTreeTopology, 6 | BaseTreeTopology, 7 | ) 8 | import attr 9 | from treeflow.tf_util import AttrsLengthMixin 10 
| 11 | TDataType = tp.TypeVar("TDataType") 12 | TShapeType = tp.TypeVar("TShapeType") 13 | 14 | 15 | @attr.s(auto_attribs=True, slots=True) 16 | class BaseUnrootedTree( 17 | tp.Generic[TDataType], 18 | AttrsLengthMixin, 19 | ): 20 | topology: BaseTreeTopology[TDataType] 21 | branch_lengths: TDataType 22 | 23 | 24 | @attr.s(auto_attribs=True, slots=True) 25 | class AbstractUnrootedTree( 26 | BaseUnrootedTree, 27 | AbstractTree[TDataType, TShapeType], 28 | tp.Generic[TDataType, TShapeType], 29 | ): 30 | topology: AbstractTreeTopology[TDataType, TShapeType] 31 | 32 | def get_unrooted_tree(self) -> AbstractUnrootedTree: 33 | return self 34 | 35 | 36 | __all__ = [BaseUnrootedTree.__name__, AbstractUnrootedTree.__name__] 37 | -------------------------------------------------------------------------------- /treeflow/tree/unrooted/numpy_unrooted_tree.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import attr 3 | from treeflow.tree.unrooted.base_unrooted_tree import AbstractUnrootedTree 4 | from treeflow.tree.topology.numpy_tree_topology import NumpyTreeTopology 5 | 6 | 7 | @attr.s(auto_attribs=True, slots=True) 8 | class NumpyUnrootedTree(AbstractUnrootedTree[np.ndarray, int]): 9 | topology: NumpyTreeTopology 10 | branch_lengths: np.ndarray 11 | -------------------------------------------------------------------------------- /treeflow/tree/unrooted/tensorflow_unrooted_tree.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import tensorflow as tf 3 | import attr 4 | from treeflow.tree.unrooted.base_unrooted_tree import AbstractUnrootedTree 5 | from treeflow.tree.topology.tensorflow_tree_topology import TensorflowTreeTopology 6 | 7 | 8 | @attr.s(auto_attribs=True, slots=True) 9 | class TensorflowUnrootedTree(AbstractUnrootedTree[tf.Tensor, tf.Tensor]): 10 | topology: TensorflowTreeTopology 11 | branch_lengths: tf.Tensor 12 | 13 | def 
with_branch_lengths(self, branch_lengths): 14 | return TensorflowUnrootedTree( 15 | topology=self.topology, branch_lengths=branch_lengths 16 | ) 17 | 18 | def __mul__(self, other: tf.Tensor) -> TensorflowUnrootedTree: 19 | return self.with_branch_lengths(self.branch_lengths * other) 20 | 21 | def __rmul__(self, other: tf.Tensor) -> TensorflowUnrootedTree: 22 | return self.__mul__(other) 23 | -------------------------------------------------------------------------------- /treeflow/vi/__init__.py: -------------------------------------------------------------------------------- 1 | from treeflow.vi.fixed_topology_advi import fit_fixed_topology_variational_approximation 2 | from treeflow.vi.marginal_likelihood import * 3 | from treeflow.vi.optimizers import * 4 | -------------------------------------------------------------------------------- /treeflow/vi/convergence_criteria/__init__.py: -------------------------------------------------------------------------------- 1 | from treeflow.vi.convergence_criteria.nonfinite import * 2 | -------------------------------------------------------------------------------- /treeflow/vi/convergence_criteria/nonfinite.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow_probability.python.optimizer.convergence_criteria import ( 3 | ConvergenceCriterion, 4 | ) 5 | 6 | 7 | def _any_nonfinite(x): 8 | nonfinite = tf.logical_not(tf.math.is_finite(x)) 9 | return tf.reduce_any(nonfinite) 10 | 11 | 12 | class NonfiniteConvergenceCriterion(ConvergenceCriterion): 13 | def __init__(self, name="NonfiniteConvergenceCriterion"): 14 | super().__init__(min_num_steps=0, name=name) 15 | 16 | def _bootstrap(self, loss, grads, parameters): 17 | return () 18 | 19 | def _one_step(self, step, loss, grads, parameters, auxiliary_state): 20 | loss_nonfinite = _any_nonfinite(loss) 21 | grads_nonfinite = [_any_nonfinite(x) for x in grads] 22 | parameters_nonfinite = 
[_any_nonfinite(x) for x in parameters] 23 | has_converged = tf.reduce_any( 24 | [loss_nonfinite] + grads_nonfinite + parameters_nonfinite 25 | ) 26 | return has_converged, auxiliary_state 27 | 28 | 29 | __all__ = ["NonfiniteConvergenceCriterion"] 30 | -------------------------------------------------------------------------------- /treeflow/vi/marginal_likelihood.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | import tensorflow as tf 3 | import numpy as np 4 | from tensorflow_probability.python.distributions import Distribution 5 | from tensorflow_probability.python.util import SeedStream 6 | import treeflow 7 | 8 | 9 | def estimate_log_ml_importance_sampling( 10 | model: Distribution, 11 | approx: Distribution, 12 | n_samples=100, 13 | approx_samples=None, 14 | return_std=False, 15 | vectorize_log_prob=True, 16 | seed=None, 17 | ) -> tp.Union[tf.Tensor, tp.Tuple[tf.Tensor, tf.Tensor]]: 18 | """ 19 | Estimate the log marginal likelihood using importance sampling 20 | This estimate can have high variance if the fit of the approximation 21 | is poor. 
22 | Parameters 23 | ---------- 24 | model 25 | The (pinned) distribution representing the prior and likelihood 26 | approx 27 | A fitted variational approximation 28 | n_samples 29 | (Optional) The number of samples to use in the estimate (default 100) 30 | return_std 31 | (Optional) Whether to also return the estimated standard 32 | """ 33 | assert not ( 34 | (not vectorize_log_prob) and (not approx_samples is None) 35 | ), "If samples are provided then vectorised log prob much be used" 36 | 37 | if vectorize_log_prob: 38 | if approx_samples is None: 39 | approx_samples = approx.sample(n_samples, seed=seed) 40 | model_log_probs = model.unnormalized_log_prob(approx_samples) 41 | approx_log_probs = approx.log_prob(approx_samples) 42 | estimates = model_log_probs - approx_log_probs 43 | else: 44 | 45 | seed = SeedStream(seed, salt="ml_estimate") 46 | estimates = np.zeros((n_samples,), dtype=treeflow.DEFAULT_FLOAT_DTYPE_NP) 47 | 48 | @tf.function 49 | def estimate_fn(seed): 50 | sample = approx.sample(seed=seed) 51 | model_log_prob = model.unnormalized_log_prob(sample) 52 | approx_log_prob = approx.log_prob(sample) 53 | return model_log_prob - approx_log_prob 54 | 55 | for i in range(n_samples): 56 | estimates[i] = estimate_fn(seed()).numpy() 57 | 58 | n_samples_float = tf.cast(n_samples, estimates.dtype) 59 | ml_estimate = tf.math.reduce_logsumexp(estimates) - tf.math.log(n_samples_float) 60 | if return_std: 61 | std = tf.math.reduce_std(estimates) 62 | ml_estimate_std = std / tf.sqrt(n_samples_float) 63 | return (ml_estimate, ml_estimate_std) 64 | else: 65 | return ml_estimate 66 | -------------------------------------------------------------------------------- /treeflow/vi/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from treeflow.vi.optimizers.robust_optimizer import * 2 | -------------------------------------------------------------------------------- /treeflow/vi/optimizers/robust_optimizer.py: 
# --------------------------------------------------------------------------------
import tensorflow as tf


class RobustOptimizer:  # Pseudo-optimizer
    """Optimizer wrapper that skips update steps whose gradients contain NaN.

    Gradients are forwarded to ``inner.apply_gradients`` unless any of them
    contains a NaN. In that case the step is skipped and a counter of
    consecutive skipped steps is incremented. The counter is reset to zero
    whenever a step is applied.

    Parameters
    ----------
    inner
        The wrapped optimizer that performs the actual variable updates
    max_retries
        (Optional) Number of consecutive NaN steps tolerated before an
        assertion fails (default 100)
    """

    def __init__(self, inner, max_retries=100):  # TODO: Count number of failed steps
        self.inner = inner
        self.max_retries = max_retries
        # Consecutive steps skipped because of NaN gradients (reset on success)
        self.retries = tf.Variable(0)

    def apply_gradients(self, grads_and_vars, name=None, **kwargs):
        """Apply gradients with the inner optimizer unless any gradient is NaN.

        Parameters
        ----------
        grads_and_vars
            Iterable of (gradient, variable) pairs. Pairs whose gradient is
            ``None`` are ignored by the NaN check and passed through to the
            inner optimizer unchanged.
        name
            (Optional) Name for the returned operation, forwarded to the inner
            optimizer
        **kwargs
            Extra keyword arguments forwarded to ``inner.apply_gradients``

        Returns
        -------
        The result of ``inner.apply_gradients`` when the step is applied,
        otherwise a scalar ``int64`` zero. Both ``tf.cond`` branches must agree,
        so this assumes the inner optimizer returns an ``int64`` iteration.

        Raises
        ------
        tf.errors.InvalidArgumentError
            When a NaN step occurs after ``max_retries`` consecutive steps have
            already been skipped.
        """
        grads_and_vars = list(grads_and_vars)
        # ``None`` marks a variable without a gradient; ``tf.math.is_nan(None)``
        # fails, so such entries are left for the inner optimizer to handle.
        # For sparse gradients only the stored values need to be checked.
        grads = [
            grad.values if isinstance(grad, tf.IndexedSlices) else grad
            for grad, _ in grads_and_vars
            if grad is not None
        ]
        if grads:
            any_nan = tf.reduce_any(
                [tf.reduce_any(tf.math.is_nan(grad)) for grad in grads]
            )
        else:
            # ``tf.reduce_any([])`` is not a valid boolean reduction
            any_nan = tf.constant(False)

        def nan_handler():
            # Fail if the limit of consecutive skipped steps has been reached,
            # otherwise record one more skipped step.
            assertion = tf.assert_less(
                self.retries,
                self.max_retries,
                message="RobustOptimizer: too many consecutive steps with NaN gradients",
            )
            with tf.control_dependencies([assertion]):
                self.retries.assign_add(1)
            return tf.zeros(
                (), dtype=tf.int64
            )  # apply_gradients returns int64 iteration

        def update_handler():
            self.retries.assign(0)
            return self.inner.apply_gradients(grads_and_vars, name=name, **kwargs)

        return tf.cond(
            any_nan,
            nan_handler,
            update_handler,
        )


__all__ = ["RobustOptimizer"]

# --------------------------------------------------------------------------------
# /treeflow/vi/progress_bar.py:
# --------------------------------------------------------------------------------
import typing as tp
from typing_extensions import Protocol
from functools import partial
import tqdm
from tensorflow_probability.python.math.minimize import MinimizeTraceableQuantities


class ProgressBarFunc(Protocol):
    """Callable that builds a tqdm progress bar for ``total`` steps."""

    def __call__(self, total: int, *args, **kwds) -> tqdm.tqdm:
        ...


class ProgressBarTraceFunctionContextManager:
    """Context manager pairing an optional tqdm bar with a trace function.

    Entering enters the bar (if any) and yields the trace function. Exiting
    forwards the exit to the bar (if any), which closes it.

    Parameters
    ----------
    tqdm
        The progress bar, or ``None`` when no bar is displayed
    trace_fn
        The trace function returned from ``__enter__``
    """

    def __init__(self, tqdm: tp.Optional[tqdm.tqdm], trace_fn: tp.Callable):
        self.tqdm = tqdm
        self.trace_fn = trace_fn

    def __enter__(self):
        if self.tqdm is not None:
            self.tqdm.__enter__()
        return self.trace_fn

    def __exit__(self, exc_type, exc_value, traceback):
        if self.tqdm is not None:
            self.tqdm.__exit__(exc_type, exc_value, traceback)


def update_trace_fn(
    mtq: MinimizeTraceableQuantities,
    trace_fn: tp.Callable,
    tqdm_instance: tqdm.tqdm,
    update_step: int = 10,
):
    """Advance ``tqdm_instance`` as optimisation proceeds, then call ``trace_fn``.

    The bar is advanced in increments of ``update_step`` when the current step
    is a multiple of ``update_step`` that the bar has not passed (or the bar's
    count equals the step). It is never advanced past the bar's ``total``, so
    ``num_steps`` need not be a multiple of ``update_step``.

    NOTE(review): ``mtq.step`` may be a Tensor. The Python-side bar update
    presumably only takes effect when the optimisation loop runs eagerly —
    confirm.

    Parameters
    ----------
    mtq
        Quantities passed to the trace function by the minimiser
    trace_fn
        The wrapped trace function
    tqdm_instance
        The progress bar to advance
    update_step
        (Optional) Number of steps between bar updates (default 10)

    Returns
    -------
    The result of ``trace_fn(mtq)``
    """
    step = mtq.step
    if (step % update_step == 0 and tqdm_instance.n < step) or tqdm_instance.n == step:
        increment = update_step
        total = tqdm_instance.total
        if total is not None:
            # Clamp so the final update does not overshoot the bar's total
            increment = min(increment, total - tqdm_instance.n)
        if increment > 0:
            tqdm_instance.update(increment)
    return trace_fn(mtq)


def make_progress_bar_trace_fn(
    trace_fn: tp.Callable,
    num_steps: int,
    progress_bar: tp.Union[ProgressBarFunc, bool] = True,
    update_step: int = 10,
):
    """Wrap ``trace_fn`` so that it also drives a progress bar.

    Parameters
    ----------
    trace_fn
        The trace function to wrap
    num_steps
        Total number of optimisation steps (used as the bar's total)
    progress_bar
        (Optional) ``True`` for a default ``tqdm.tqdm`` bar, ``False`` for no
        bar, or a callable building a bar from ``total`` (default True)
    update_step
        (Optional) Number of steps between bar updates (default 10)

    Returns
    -------
    A context manager that yields the (possibly wrapped) trace function and
    closes the bar on exit
    """
    total = num_steps
    if isinstance(progress_bar, bool):
        if progress_bar:
            tqdm_instance = tqdm.tqdm(total=total)
        else:
            tqdm_instance = None
    else:
        tqdm_instance = progress_bar(total=total)

    if tqdm_instance is None:
        wrapped_trace_fn = trace_fn
    else:
        wrapped_trace_fn = partial(
            update_trace_fn,
            trace_fn=trace_fn,
            tqdm_instance=tqdm_instance,
            update_step=update_step,
        )

    return ProgressBarTraceFunctionContextManager(tqdm_instance, wrapped_trace_fn)


# --------------------------------------------------------------------------------
# /treeflow/vi/util.py:
# --------------------------------------------------------------------------------
import typing as tp
from collections import namedtuple
import tensorflow as tf
from tensorflow_probability.python.math import MinimizeTraceableQuantities

# Per-step trace record: the loss value and the variables being optimised
VIResults = namedtuple("VIResults", ("loss", "parameters"))
def default_vi_trace_fn(
    traceable_quantities: MinimizeTraceableQuantities,
    variables_dict: tp.Dict[str, tf.Variable],
) -> VIResults:
    """Default trace function for variational inference optimisation.

    Parameters
    ----------
    traceable_quantities
        Quantities supplied by the minimiser at the current step
    variables_dict
        Mapping from name to optimised variable; stored as-is (not copied)

    Returns
    -------
    A ``VIResults`` holding the current loss and ``variables_dict``
    """
    current_loss = traceable_quantities.loss
    # TODO: Name parameters
    return VIResults(current_loss, variables_dict)


__all__ = ["VIResults", "default_vi_trace_fn"]
# --------------------------------------------------------------------------------