├── .github
│   └── workflows
│       └── ci.yml
├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── Makefile
├── README.md
├── generate_epiToPublicAndDate.ipynb
├── generate_epiToPublicAndDate.py
├── notebooks
│   ├── NS_ratio.ipynb
│   ├── characterize_predicted.ipynb
│   ├── contrast-strains.ipynb
│   ├── dissect_delta_logR.ipynb
│   ├── escape_calc.ipynb
│   ├── explore-gisaid-af.ipynb
│   ├── explore-gisaid.ipynb
│   ├── explore-jhu-time-series.ipynb
│   ├── explore-mutation-annotated-tree.ipynb
│   ├── explore-nextclade-counts.ipynb
│   ├── explore-nextclade.ipynb
│   ├── explore-nextstrain-open-data.ipynb
│   ├── explore-owid-vaccinations.ipynb
│   ├── explore-usher.ipynb
│   ├── explore_pairs_of_lineages.ipynb
│   ├── explore_pango_lineages.ipynb
│   ├── grid_search.ipynb
│   ├── ind_emergences_violin.ipynb
│   ├── make_mutations_subset.ipynb
│   ├── mini_mutrans.ipynb
│   ├── mutrans-vary-coef-scale.ipynb
│   ├── mutrans.ipynb
│   ├── mutrans_backtesting-fullrun.ipynb
│   ├── mutrans_backtesting.ipynb
│   ├── mutrans_forecasts.ipynb
│   ├── mutrans_gene.ipynb
│   ├── mutrans_hack_7_2022.ipynb
│   ├── mutrans_loo.ipynb
│   ├── paper
│   ├── preprocess-pairwise-allele-fractions.ipynb
│   ├── rank_mutations.ipynb
│   ├── results
│   └── slop_mapping.ipynb
├── paper
│   ├── .gitignore
│   ├── BA.1_logistic_regression_all_US.jpg
│   ├── Makefile
│   ├── N_S_peak_ratio.png
│   ├── N_S_ratio.png
│   ├── README.md
│   ├── Science.bst
│   ├── accession_ids.txt.xz
│   ├── backtesting
│   │   ├── L1_error_barplot_England.png
│   │   ├── L1_error_barplot_all.png
│   │   ├── L1_error_barplot_other.png
│   │   ├── L1_error_barplot_top100-1000.png
│   │   ├── L1_error_barplot_top100.png
│   │   ├── backtesting_day_248_early_prediction_england.png
│   │   ├── backtesting_day_346_early_prediction_england.png
│   │   ├── backtesting_day_360_early_prediction_england.png
│   │   ├── backtesting_day_430.png
│   │   ├── backtesting_day_444.png
│   │   ├── backtesting_day_458.png
│   │   ├── backtesting_day_472.png
│   │   ├── backtesting_day_486.png
│   │   ├── backtesting_day_500.png
│   │   ├── backtesting_day_514.png
│   │   ├── backtesting_day_514_early_prediction_england.png
│   │   ├── backtesting_day_528.png
│   │   ├── backtesting_day_528_early_prediction_england.png
│   │   ├── backtesting_day_542.png
│   │   ├── backtesting_day_556.png
│   │   ├── backtesting_day_570.png
│   │   ├── backtesting_day_584.png
│   │   ├── backtesting_day_598.png
│   │   ├── backtesting_day_612.png
│   │   ├── backtesting_day_612_early_prediction_england.png
│   │   ├── backtesting_day_626.png
│   │   ├── backtesting_day_640.png
│   │   ├── backtesting_day_640_early_prediction_england.png
│   │   ├── backtesting_day_654.png
│   │   ├── backtesting_day_668.png
│   │   ├── backtesting_day_682.png
│   │   ├── backtesting_day_696.png
│   │   ├── backtesting_day_710.png
│   │   ├── backtesting_day_710_early_prediction_england.png
│   │   ├── backtesting_day_724.png
│   │   ├── backtesting_day_724_early_prediction_england.png
│   │   ├── backtesting_day_738.png
│   │   ├── backtesting_day_738_early_prediction_england.png
│   │   ├── backtesting_day_752.png
│   │   ├── backtesting_day_752_early_prediction_england.png
│   │   ├── backtesting_day_766.png
│   │   ├── backtesting_day_766_early_prediction_Asia_Europe_Africa.png
│   │   ├── backtesting_day_766_early_prediction_Massachusetts_for_heatmap.png
│   │   ├── backtesting_day_766_early_prediction_Massachusetts_for_heatmap_nolegend.png
│   │   ├── backtesting_day_766_early_prediction_USA_France_England_Brazil_Australia_Russia.png
│   │   ├── backtesting_day_766_early_prediction_brazil_for_heatmap.png
│   │   ├── backtesting_day_766_early_prediction_brazil_for_heatmap_nolegend.png
│   │   ├── backtesting_day_766_early_prediction_england.png
│   │   ├── backtesting_day_766_early_prediction_england_for_heatmap.png
│   │   └── backtesting_day_766_early_prediction_england_for_heatmap_nolegend.png
│   ├── binding_retained.png
│   ├── clade_distribution.png
│   ├── coef_scale_manhattan.png
│   ├── coef_scale_pearson.png
│   ├── coef_scale_scatter.png
│   ├── coef_scale_volcano.png
│   ├── convergence.png
│   ├── convergence_loss.png
│   ├── convergent_evolution.jpg
│   ├── deep_scanning.png
│   ├── delta_logR_breakdown.png
│   ├── dobanno_2021_immunogenicity_N.png
│   ├── forecast.png
│   ├── formatted
│   │   ├── Figure1.ai
│   │   ├── Figure1.jpg
│   │   ├── Figure2.ai
│   │   ├── Figure2.jpg
│   │   ├── Figure2_alt.ai
│   │   └── Figure2_alt.jpg
│   ├── gene_ratios.png
│   ├── ind_emergences_violin.png
│   ├── lineage_agreement.png
│   ├── lineage_heterogeneity.png
│   ├── lineage_prediction.png
│   ├── main.bib
│   ├── main.md
│   ├── manhattan.png
│   ├── manhattan_N.png
│   ├── manhattan_N_coef_scale_0.05.png
│   ├── manhattan_ORF1a.png
│   ├── manhattan_ORF1a_coef_scale_0.05.png
│   ├── manhattan_ORF1b.png
│   ├── manhattan_ORF1b_coef_scale_0.05.png
│   ├── manhattan_ORF3a.png
│   ├── manhattan_ORF3a_coef_scale_0.05.png
│   ├── manhattan_S.png
│   ├── manhattan_S_coef_scale_0.05.png
│   ├── moran.csv
│   ├── multinomial_LR_vs_pyr0.jpg
│   ├── mutation_agreement.png
│   ├── mutation_europe_boxplot_rankby_s.png
│   ├── mutation_europe_boxplot_rankby_t.png
│   ├── mutation_scoring
│   │   ├── by_gene_heatmap.png
│   │   ├── plot_individual_mutation_significance.png
│   │   ├── plot_individual_mutation_significance_M.png
│   │   ├── plot_individual_mutation_significance_S.png
│   │   ├── plot_individual_mutation_significance_S__K_to_N.png
│   │   ├── plot_k_to_n_mutation_significance.png
│   │   ├── plot_per_gene_aggregate_sign.png
│   │   ├── plot_pval_vs_top_genes_E.png
│   │   ├── plot_pval_vs_top_genes_N.png
│   │   ├── plot_pval_vs_top_genes_ORF10.png
│   │   ├── plot_pval_vs_top_genes_ORF1a.png
│   │   ├── plot_pval_vs_top_genes_ORF1b.png
│   │   ├── plot_pval_vs_top_genes_ORF3a.png
│   │   ├── plot_pval_vs_top_genes_ORF6.png
│   │   ├── plot_pval_vs_top_genes_ORF7b.png
│   │   ├── plot_pval_vs_top_genes_ORF8.png
│   │   ├── plot_pval_vs_top_genes_ORF9b.png
│   │   ├── plot_pval_vs_top_genes_S.png
│   │   ├── pvals_vs_top_genes.png
│   │   └── top_substitutions.png
│   ├── mutation_summaries.jpg
│   ├── mutations.tsv
│   ├── region_distribution.png
│   ├── relative_growth_rate.jpg
│   ├── schematic_overview.key
│   ├── schematic_overview.png
│   ├── scibib.bib
│   ├── scicite.sty
│   ├── scifile.tex
│   ├── spectrum_transmissibility.jpg
│   ├── strain_emergence.png
│   ├── strain_europe_boxplot.png
│   ├── strain_prevalence.png
│   ├── strain_prevalence_data.csv
│   ├── strains.tsv
│   ├── supplement.tex
│   ├── supplement_table_3_02_05_2022.jpeg
│   ├── table2.jpeg
│   ├── table_1.jpeg
│   ├── table_1.tsv
│   ├── table_1_02_05_2022.jpeg
│   ├── table_2_02_05_2022.jpeg
│   ├── vary_gene_elbo.png
│   ├── vary_gene_likelihood.png
│   ├── vary_gene_loss.png
│   ├── vary_nsp_elbo.png
│   ├── vary_nsp_likelihood.png
│   ├── volcano.csv
│   └── volcano.png
├── pyrocov
│   ├── __init__.py
│   ├── aa.py
│   ├── align.py
│   ├── distributions.py
│   ├── external
│   │   ├── __init__.py
│   │   └── usher
│   │       ├── LICENSE
│   │       ├── README.md
│   │       ├── __init__.py
│   │       └── parsimony_pb2.py
│   ├── fasta.py
│   ├── geo.py
│   ├── growth.py
│   ├── hashsubset.py
│   ├── io.py
│   ├── markov_tree.py
│   ├── mutrans.py
│   ├── mutrans_helpers.py
│   ├── ops.py
│   ├── pangolin.py
│   ├── phylo.py
│   ├── plot_additional_figs.R
│   ├── plotting.py
│   ├── sarscov2.py
│   ├── sketch.cpp
│   ├── sketch.py
│   ├── softmax_tree.py
│   ├── special.py
│   ├── stats.py
│   ├── strains.py
│   ├── substitution.py
│   ├── usher.py
│   └── util.py
├── scripts
│   ├── analyze_nextstrain.py
│   ├── fix_columns.py
│   ├── git_pull.py
│   ├── moran.py
│   ├── mutrans.py
│   ├── preprocess_credits.py
│   ├── preprocess_gisaid.py
│   ├── preprocess_nextstrain.py
│   ├── preprocess_usher.py
│   ├── pull_gisaid.sh
│   ├── pull_nextstrain.sh
│   ├── pull_usher.sh
│   ├── rank_mutations.py
│   ├── run_backtesting.sh
│   └── update_headers.py
├── setup.cfg
├── setup.py
└── test
    ├── __init__.py
    ├── conftest.py
    ├── test_distributions.py
    ├── test_io.py
    ├── test_markov_tree.py
    ├── test_ops.py
    ├── test_phylo.py
    ├── test_sarscov2.py
    ├── test_sketch.py
    ├── test_softmax_tree.py
    ├── test_strains.py
    ├── test_substitution.py
    └── test_usher.py

/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 | 
3 | on:
4 |   push:
5 |     branches: [master]
6 |   pull_request:
7 |     branches: [master]
8 | 
9 |   # Allows you to run this workflow manually from the Actions tab
10 |   workflow_dispatch:
11 | 
12 | env:
13 |   CXX: g++-8
14 |   CC: gcc-8
15 | 
16 | jobs:
17 |   lint:
18 |     runs-on: ubuntu-20.04
19 |     strategy:
20 |       matrix:
21 |         python-version: [3.7]
22 |     steps:
23 |     - uses: actions/checkout@v2
24 |     - name: Set up Python ${{ matrix.python-version }}
25 |       uses: actions/setup-python@v2
26 |       with:
27 |         python-version: ${{ matrix.python-version }}
28 |     - name: Install dependencies
29 |       run: |
30 |         python -m pip install --upgrade pip wheel setuptools
31 |         pip install flake8 black "isort>=5.0" mypy types-protobuf
32 |     - name: Lint
33 |       run: |
34 |         make lint
35 |   unit:
36 |     runs-on: ubuntu-20.04
37 |     needs: lint
38 |     strategy:
39 |       matrix:
40 |         python-version: [3.7]
41 |     steps:
42 |     - uses: actions/checkout@v2
43 |     - name: Set up Python ${{ matrix.python-version }}
44 |       uses: actions/setup-python@v2
45 |       with:
46 |         python-version: ${{ matrix.python-version }}
47 |     - name: Install dependencies
48 |       run: |
49 |         sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
50 |         sudo apt-get update
51 |         sudo apt-get install gcc-8 g++-8 ninja-build
52 |         python -m pip install --upgrade pip wheel setuptools
53 |         pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
54 |         pip install .[test]
55 |         pip install coveralls
56 |         pip freeze
57 |     - name: Run unit tests
58 |       run: |
59 |         make test
60 | 
61 | 

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | build/
2 | nextalign
3 | nextclade
4 | nextclade_results/
5 | /data
6 | /results
7 | /results.*
8 | *.swp
9 | temp.*
10 | 
11 | # Byte-compiled / optimized / DLL files
12 | __pycache__/
13 | *.py[cod]
14 | *$py.class
15 | 
16 | # C extensions
17 | *.so
18 | 
19 | # Distribution / packaging
20 | .Python
21 | build/
22 | develop-eggs/
23 | dist/
24 | downloads/
25 | eggs/
26 | .eggs/
27 | lib/
28 | lib64/
29 | parts/
30 | sdist/
31 | var/
32 | wheels/
33 | pip-wheel-metadata/
34 | share/python-wheels/
35 | *.egg-info/
36 | .installed.cfg
37 | *.egg
38 | MANIFEST
39 | 
40 | # PyInstaller
41 | # Usually these files are written by a python script from a template
42 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
43 | *.manifest
44 | *.spec
45 | 
46 | # Installer logs
47 | pip-log.txt
48 | pip-delete-this-directory.txt
49 | 
50 | # Unit test / coverage reports
51 | htmlcov/
52 | .tox/
53 | .nox/
54 | .coverage
55 | .coverage.*
56 | .cache
57 | nosetests.xml
58 | coverage.xml
59 | *.cover
60 | *.py,cover
61 | .hypothesis/
62 | .pytest_cache/
63 | 
64 | # Translations
65 | *.mo
66 | *.pot
67 | 
68 | # Django stuff:
69 | *.log
70 | local_settings.py
71 | db.sqlite3
72 | db.sqlite3-journal
73 | 
74 | # Flask stuff:
75 | instance/
76 | .webassets-cache
77 | 
78 | # Scrapy stuff:
79 | .scrapy
80 | 
81 | # Sphinx documentation
82 | docs/_build/
83 | 
84 | # PyBuilder
85 | target/
86 | 
87 | # Jupyter Notebook
88 | .ipynb_checkpoints
89 | 
90 | # IPython
91 | profile_default/
92 | ipython_config.py
93 | 
94 | # pyenv
95 | .python-version
96 | 
97 | # pipenv
98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
101 | # install all needed dependencies.
102 | #Pipfile.lock
103 | 
104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
105 | __pypackages__/
106 | 
107 | # Celery stuff
108 | celerybeat-schedule
109 | celerybeat.pid
110 | 
111 | # SageMath parsed files
112 | *.sage.py
113 | 
114 | # Environments
115 | .env
116 | .venv
117 | env/
118 | venv/
119 | ENV/
120 | env.bak/
121 | venv.bak/
122 | 
123 | # Spyder project settings
124 | .spyderproject
125 | .spyproject
126 | 
127 | # Rope project settings
128 | .ropeproject
129 | 
130 | # mkdocs documentation
131 | /site
132 | 
133 | # mypy
134 | .mypy_cache/
135 | .dmypy.json
136 | dmypy.json
137 | 
138 | # Pyre type checker
139 | .pyre/
140 | 

--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | ## File organization
2 | 
3 | - The root directory `/` contains scripts and jupyter notebooks.
4 | - All module code lives in the `/pyrocov/` package directory.
5 |   As a rule of thumb, if you're importing a .py file in any other file,
6 |   it should live in the `/pyrocov/` directory; if your .py file has an `if __name__ == "__main__"` check, then it should live in the root directory.
7 | - Notebooks should be cleared of data before committing.
8 | - The `/results/` directory (not under git control) contains large intermediate data
9 |   including preprocessed input data and outputs from models and prediction.
10 | - The `/paper/` directory contains git-controlled output for the paper, namely
11 |   plots, .tsv files, and .fasta files for sharing outside of the dev team.
12 | 
13 | ## Committing code
14 | 
15 | - The current policy is to allow pushing of code directly, without PRs.
16 | - If you would like code reviewed, please submit a PR and tag a reviewer.
17 | - Please run `make format` and `make lint` before committing code (see the example below).
18 | - We'd like to resurrect `make test`, but it is currently broken.
19 | - Each notebook has an owner (see git history);
20 |   changes to notebooks by non-owners should go through a pull request.
21 |   This rule is intended to reduce difficult merge conflicts in jupyter notebooks.
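
A minimal sketch of the pre-commit workflow described in the bullets above, chaining the Makefile targets defined in `/Makefile` (the commit message is a placeholder):

```sh
make format        # autoformat with black and isort, refresh license headers
make lint          # flake8, black --check, isort --check, header check, mypy
git commit -m "Describe your change"
```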
22 | 
23 | 

--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | SHELL := /bin/bash
2 | 
3 | ###########################################################################
4 | # installation
5 | 
6 | install: FORCE
7 | 	pip install -e .[test]
8 | 
9 | install-pangolin:
10 | 	conda install -y -c bioconda -c conda-forge -c defaults pangolin
11 | 
12 | install-usher:
13 | 	conda install -y -c bioconda -c conda-forge -c defaults usher
14 | 
15 | ###########################################################################
16 | # ci tasks
17 | 
18 | lint: FORCE
19 | 	flake8 --extend-exclude=pyrocov/external
20 | 	black --extend-exclude='notebooks|pyrocov/external' --check .
21 | 	isort --check --skip=pyrocov/external .
22 | 	python scripts/update_headers.py --check
23 | 	mypy . --exclude=build/
24 | 
25 | format: FORCE
26 | 	black --extend-exclude='notebooks|pyrocov/external' .
27 | 	isort --skip=pyrocov/external .
28 | 	python scripts/update_headers.py
29 | 
30 | test: lint FORCE
31 | 	python scripts/git_pull.py --no-update cov-lineages/pango-designation
32 | 	python scripts/git_pull.py --no-update cov-lineages/pangoLEARN
33 | 	python scripts/git_pull.py --no-update nextstrain/nextclade
34 | 	pytest -v -n auto test
35 | 	test -e results/columns.3000.pkl \
36 | 	&& python scripts/mutrans.py --test -n 2 -s 2 \
37 | 	|| echo skipping test
38 | 
39 | ###########################################################################
40 | # Main processing workflows
41 | 
42 | # The DO_NOT_UPDATE logic aims to avoid someone accidentally updating a frozen
43 | # results directory.
44 | update: FORCE
45 | 	! test -f results/DO_NOT_UPDATE
46 | 	scripts/pull_nextstrain.sh
47 | 	scripts/pull_usher.sh
48 | 	python scripts/git_pull.py cov-lineages/pango-designation
49 | 	python scripts/git_pull.py cov-lineages/pangoLEARN
50 | 	python scripts/git_pull.py CSSEGISandData/COVID-19
51 | 	python scripts/git_pull.py nextstrain/nextclade
52 | 	echo "frozen" > results/DO_NOT_UPDATE
53 | 
54 | preprocess: FORCE
55 | 	python scripts/preprocess_usher.py
56 | 
57 | preprocess-gisaid: FORCE
58 | 	python scripts/preprocess_usher.py \
59 | 	  --tree-file-in results/gisaid/gisaidAndPublic.masked.pb.gz \
60 | 	  --gisaid-metadata-file-in results/gisaid/metadata_2022_*_*.tsv.gz
61 | 
62 | analyze: FORCE
63 | 	python scripts/mutrans.py --vary-holdout
64 | 	python scripts/mutrans.py --vary-gene
65 | 	python scripts/mutrans.py --vary-nsp
66 | 	python scripts/mutrans.py --vary-leaves=9999 --num-steps=2001
67 | 
68 | backtesting-piecewise: FORCE
69 | 	# Generates all the backtesting models piece by piece so that it can be run on a GPU-enabled machine
70 | 	python scripts/mutrans.py --backtesting-max-day `seq -s, 150 14 220` --forecast-steps 12
71 | 	python scripts/mutrans.py --backtesting-max-day `seq -s, 220 14 500` --forecast-steps 12
72 | 	python scripts/mutrans.py --backtesting-max-day `seq -s, 500 14 625` --forecast-steps 12
73 | 	python scripts/mutrans.py --backtesting-max-day `seq -s, 626 14 700` --forecast-steps 12
74 | 	python scripts/mutrans.py --backtesting-max-day `seq -s, 710 14 766` --forecast-steps 12
75 | 
76 | backtesting-nofeatures: FORCE
77 | 	# Generates all the backtesting models piece by piece so that it can be run on a GPU-enabled machine
78 | 	python scripts/mutrans.py --backtesting-max-day `seq -s, 150 14 220` --forecast-steps 12 --model-type reparam-localinit-nofeatures
79 | 	python scripts/mutrans.py --backtesting-max-day `seq -s, 220 14 500` --forecast-steps 12 --model-type reparam-localinit-nofeatures
80 | 	python scripts/mutrans.py --backtesting-max-day `seq -s, 500 14 625` --forecast-steps 12 --model-type reparam-localinit-nofeatures
81 | 	python scripts/mutrans.py --backtesting-max-day `seq -s, 626 14 700` --forecast-steps 12 --model-type reparam-localinit-nofeatures
82 | 	python scripts/mutrans.py --backtesting-max-day `seq -s, 710 14 766` --forecast-steps 12 --model-type reparam-localinit-nofeatures
83 | 
84 | backtesting-complete: FORCE
85 | 	# Run only after running backtesting-piecewise on a machine with > 500GB RAM to aggregate results
86 | 	python scripts/mutrans.py --backtesting-max-day `seq -s, 150 14 766` --forecast-steps 12
87 | 
88 | backtesting: FORCE
89 | 	# Maximum possible run on a high-memory GPU machine
90 | 	python scripts/mutrans.py --backtesting-max-day `seq -s, 430 14 766` --forecast-steps 12
91 | 
92 | backtesting-short: FORCE
93 | 	# For quick testing of backtesting code changes
94 | 	python scripts/mutrans.py --backtesting-max-day `seq -s, 500 14 700` --forecast-steps 12
95 | 
96 | EXCLUDE='.*\.json$$|.*mutrans\.pt$$|.*temp\..*|.*\.[EI](gene|region)=.*\.pt$$|.*__(gene|region|lineage)__.*\.pt$$'
97 | 
98 | push: FORCE
99 | 	gsutil -m -o GSUtil:parallel_composite_upload_threshold=150M \
100 | 	  rsync -r -x $(EXCLUDE) \
101 | 	  $(shell readlink results)/ \
102 | 	  gs://pyro-cov/$(shell readlink results | grep -o 'results\.[-0-9]*')/
103 | 
104 | pull: FORCE
105 | 	gsutil -m rsync -r -x $(EXCLUDE) \
106 | 	  gs://pyro-cov/$(shell readlink results | grep -o 'results\.[-0-9]*')/ \
107 | 	  $(shell readlink results)/
108 | 
109 | ###########################################################################
110 | # TODO remove these user-specific targets
111 | 
112 | data:
113 | 	ln -sf ~/Google\ Drive\ File\ Stream/Shared\ drives/Pyro\ CoV data
114 | 
115 | ssh-cpu:
116 | 	gcloud compute ssh --project pyro-284215 --zone us-central1-c pyro-cov-fritzo-vm -- -AX -t 'cd pyro-cov ; bash --login'
117 | 
118 | ssh-gpu:
119 | 	gcloud compute ssh --project pyro-284215 --zone us-central1-c pyro-fritzo-vm-gpu -- -AX -t 'cd pyro-cov ; bash --login'
120 | 
121 | FORCE:
122 | 
123 | 

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [![Github Release](https://img.shields.io/github/v/release/broadinstitute/pyro-cov)](https://github.com/broadinstitute/pyro-cov/releases)
2 | [![DOI](https://zenodo.org/badge/292037402.svg)](https://zenodo.org/badge/latestdoi/292037402)
3 | 
4 | # Pyro models for SARS-CoV-2 analysis
5 | 
6 | ![Overview](paper/schematic_overview.png)
7 | 
8 | Supporting material for the paper ["Analysis of 6.4 million SARS-CoV-2 genomes identifies mutations associated with fitness"](https://www.science.org/doi/10.1126/science.abm1208) ([medRxiv](https://www.medrxiv.org/content/10.1101/2021.09.07.21263228v2)). Figures and supplementary data for that paper are in the [paper/](paper/) directory.
9 | 
10 | This is open source, but we do not intend to support this code for use by outside groups. To use outputs of this model, we recommend ingesting the tables [strains.tsv](paper/strains.tsv) and [mutations.tsv](paper/mutations.tsv).
11 | 
12 | ## Reproducing
13 | 
14 | ### Install software
15 | 
16 | Clone this repository:
17 | ```sh
18 | git clone git@github.com:broadinstitute/pyro-cov
19 | cd pyro-cov
20 | ```
21 | 
22 | Install this Python package:
23 | ```sh
24 | pip install -e .
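# Optional sketch: to pull in the test dependencies as well, use the extras
# form that the Makefile's `install` target runs:
#   pip install -e .[test]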
25 | ```
26 | 
27 | ### Get access to GISAID data
28 | 
29 | Work with GISAID to get a data agreement.
30 | Define the following environment variables:
31 | ```
32 | GISAID_USERNAME
33 | GISAID_PASSWORD
34 | GISAID_FEED
35 | ```
36 | For example, my username is `fritz` and my GISAID feed is `broad2`.
37 | 
38 | ### Download data
39 | This downloads data from GISAID and clones repos for other data sources.
40 | ```sh
41 | make update
42 | ```
43 | 
44 | ### Preprocess data
45 | 
46 | This takes under an hour.
47 | Results are cached in the `results/` directory, so re-running on newly pulled data should be able to re-use alignment and PANGO lineage classification work.
48 | ```sh
49 | make preprocess
50 | ```
51 | 
52 | ### Analyze data
53 | ```sh
54 | make analyze
55 | ```
56 | 
57 | ### Generate plots and tables
58 | Plots and tables are generated by running various notebooks:
59 | - [mutrans.ipynb](notebooks/mutrans.ipynb)
60 | - [mutrans_backtesting.ipynb](notebooks/mutrans_backtesting.ipynb)
61 | - [mutrans_gene.ipynb](notebooks/mutrans_gene.ipynb)
62 | 
63 | ## Citing
64 | 
65 | If you use this software or predictions in the [paper](paper/) directory, please consider citing:
66 | 
67 | ```
68 | @article {Obermeyer2021.09.07.21263228,
69 |   author = {Obermeyer, Fritz and
70 |             Schaffner, Stephen F. and
71 |             Jankowiak, Martin and
72 |             Barkas, Nikolaos and
73 |             Pyle, Jesse D. and
74 |             Park, Daniel J. and
75 |             MacInnis, Bronwyn L. and
76 |             Luban, Jeremy and
77 |             Sabeti, Pardis C. and
78 |             Lemieux, Jacob E.},
79 |   title = {Analysis of 2.1 million SARS-CoV-2 genomes identifies mutations associated with transmissibility},
80 |   elocation-id = {2021.09.07.21263228},
81 |   year = {2021},
82 |   doi = {10.1101/2021.09.07.21263228},
83 |   publisher = {Cold Spring Harbor Laboratory Press},
84 |   URL = {https://www.medrxiv.org/content/early/2021/09/13/2021.09.07.21263228},
85 |   eprint = {https://www.medrxiv.org/content/early/2021/09/13/2021.09.07.21263228.full.pdf},
86 |   journal = {medRxiv}
87 | }
88 | ```
89 | 

--------------------------------------------------------------------------------
/generate_epiToPublicAndDate.ipynb:
--------------------------------------------------------------------------------
1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "2f544c93-83ef-455f-b3d0-b507d6440376", 6 | "metadata": {}, 7 | "source": [ 8 | "Preprocess files for parsing" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 4, 14 | "id": "7dc48887-b123-43bb-9994-321a4c47abee", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import pandas as pd" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": 6, 24 | "id": "5352238b-51d6-456f-99c7-1e125e2099b2", 25 | "metadata": {}, 26 | "outputs": [ 27 | { 28 | "name": "stderr", 29 | "output_type": "stream", 30 | "text": [ 31 | "/tmp/ipykernel_2032/3544606020.py:1: DtypeWarning: Columns (16) have mixed types. 
Specify dtype option on import or set low_memory=False.\n", 32 | " gisaid_meta = pd.read_csv(\"results/gisaid/metadata_2022_07_27.tsv.gz\", sep=\"\\t\")\n" 33 | ] 34 | } 35 | ], 36 | "source": [ 37 | "gisaid_meta = pd.read_csv(\"results/gisaid/metadata_2022_07_27.tsv.gz\", sep=\"\\t\")" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 7, 43 | "id": "9b3447d7-13d6-4f76-a922-8f564259bd85", 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "gisaid_meta[\"vname\"] = gisaid_meta[\"Virus name\"].str.replace(\"hCoV-19/\",\"\")\n", 48 | "gisaid_meta[\"vname2\"] = gisaid_meta[\"vname\"]" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 8, 54 | "id": "2b551436-24ed-4949-b0d2-ac863e64a5f0", 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "epi_map = gisaid_meta[[\"Accession ID\", \"vname\", \"vname2\", \"Collection date\"]]" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "id": "a20539c6-4843-4107-8e0c-f7e621103659", 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "epi_map = epi_map.sort_values(by=\"Accession ID\", ascending = True)" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "id": "66119639-6cad-4bd0-aa41-7d0d161c02f7", 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "epi_map" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "id": "c8564631-e7bc-4907-8790-1935b48df902", 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "epi_map.to_csv(\"results/gisaid/epiToPublicAndDate.latest\", header=False, sep=\"\\t\", index=False)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "id": "017c3a6d-5fbc-444d-80cd-5fd8ccc8324d", 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "?pd.to_csv" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "id": "5b63f7f3-3742-4145-a873-d21ee6ec6f70", 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "?pd.DataFrame.to_csv" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "id": "6bc024e5-1f2c-4598-9943-6d6a77be2bd8", 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [] 118 | } 119 | ], 120 | "metadata": { 121 | "kernelspec": { 122 | "display_name": "Python 3 (ipykernel)", 123 | "language": "python", 124 | "name": "python3" 125 | }, 126 | "language_info": { 127 | "codemirror_mode": { 128 | "name": "ipython", 129 | "version": 3 130 | }, 131 | "file_extension": ".py", 132 | "mimetype": "text/x-python", 133 | "name": "python", 134 | "nbconvert_exporter": "python", 135 | "pygments_lexer": "ipython3", 136 | "version": "3.9.12" 137 | } 138 | }, 139 | "nbformat": 4, 140 | "nbformat_minor": 5 141 | } 142 | -------------------------------------------------------------------------------- /generate_epiToPublicAndDate.py: -------------------------------------------------------------------------------- 1 | # Copyright Contributors to the Pyro-Cov project. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | #!/usr/bin/env python 5 | # coding: utf-8 6 | 7 | # Preprocess files for parsing 8 | 9 | # In[1]: 10 | 11 | 12 | import pandas as pd 13 | 14 | # In[2]: 15 | 16 | 17 | gisaid_meta = pd.read_csv("results/gisaid/metadata_2022_08_08.tsv.gz", sep="\t") 18 | 19 | 20 | # In[3]: 21 | 22 | 23 | gisaid_meta["vname"] = gisaid_meta["Virus name"].str.replace("hCoV-19/", "") 24 | gisaid_meta["vname2"] = gisaid_meta["vname"] 25 | 26 | 27 | # In[4]: 28 | 29 | 30 | epi_map = gisaid_meta[["Accession ID", "vname", "vname2", "Collection date"]] 31 | 32 | 33 | # In[5]: 34 | 35 | 36 | epi_map = epi_map.sort_values(by="Accession ID", ascending=True) 37 | 38 | 39 | # In[6]: 40 | 41 | 42 | epi_map 43 | 44 | 45 | # In[7]: 46 | 47 | 48 | epi_map.to_csv( 49 | "results/gisaid/epiToPublicAndDate.latest", header=False, sep="\t", index=False 50 | ) 51 | 52 | 53 | # In[8]: 54 | 55 | 56 | # get_ipython().run_line_magic('pinfo', 'pd.to_csv') 57 | 58 | 59 | 60 | # In[ ]: 61 | -------------------------------------------------------------------------------- /notebooks/contrast-strains.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## What are the salient differences between strains?\n", 8 | "\n", 9 | "This notebook addresses a [question of Tom Wenseleers](https://twitter.com/TWenseleers/status/1438780125479329792) about the salient differences between two strains, say between B.1.617.1 and the similar B.1.617.2 and B.1.617.3. You should be able to run this notebook merely after git cloning, to explore other salient differences." 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "First load the precomputed data." 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "import pandas as pd\n", 26 | "strains_df = pd.read_csv(\"paper/strains.tsv\", sep=\"\\t\")\n", 27 | "mutations_df = pd.read_csv(\"paper/mutations.tsv\", sep=\"\\t\")" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "Convert to dictionaries for easier use." 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 2, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "mutations_per_strain = {\n", 44 | " strain: frozenset(mutations.split(\",\"))\n", 45 | " for strain, mutations in zip(strains_df[\"strain\"], strains_df[\"mutations\"])\n", 46 | "}" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 3, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "effect_size = dict(zip(mutations_df[\"mutation\"], mutations_df[\"Δ log R\"]))" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "Create a helper to explore pairwise differences." 
63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 4, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "def print_diff(strain1, strain2, max_results=10):\n", 72 | " mutations1 = mutations_per_strain[strain1]\n", 73 | " mutations2 = mutations_per_strain[strain2]\n", 74 | " diff = [(m, effect_size[m]) for m in mutations1 ^ mutations2]\n", 75 | " diff.sort(key=lambda me: -abs(me[1]))\n", 76 | " print(f\"{strain1} versus {strain2}\")\n", 77 | " print(\"AA mutation Δ log R Present in strain\")\n", 78 | " print(\"--------------------------------------------\")\n", 79 | " for m, e in diff[:max_results]:\n", 80 | " strain = strain1 if m in mutations1 else strain2\n", 81 | " print(f\"{m: <15s} {e: <10.3g} {strain}\")" 82 | ] 83 | }, 84 | { 85 | "cell_type": "markdown", 86 | "metadata": {}, 87 | "source": [ 88 | "Examine some example sequences." 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 5, 94 | "metadata": {}, 95 | "outputs": [ 96 | { 97 | "name": "stdout", 98 | "output_type": "stream", 99 | "text": [ 100 | "B.1.617.2 versus B.1.617.1\n", 101 | "AA mutation Δ log R Present in strain\n", 102 | "--------------------------------------------\n", 103 | "ORF8:L84S 0.235 B.1.617.1\n", 104 | "ORF1a:S318L 0.206 B.1.617.1\n", 105 | "ORF1a:G392D 0.111 B.1.617.1\n", 106 | "ORF1b:P314F 0.11 B.1.617.1\n", 107 | "N:A220V 0.0752 B.1.617.2\n", 108 | "S:E484Q 0.0703 B.1.617.1\n", 109 | "S:A222V 0.0677 B.1.617.2\n", 110 | "ORF1a:M585V -0.0624 B.1.617.2\n", 111 | "ORF1a:S2535L 0.0612 B.1.617.1\n", 112 | "N:S194L 0.0599 B.1.617.1\n" 113 | ] 114 | } 115 | ], 116 | "source": [ 117 | "print_diff(\"B.1.617.2\", \"B.1.617.1\")" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 6, 123 | "metadata": {}, 124 | "outputs": [ 125 | { 126 | "name": "stdout", 127 | "output_type": "stream", 128 | "text": [ 129 | "B.1.617.2 versus B.1.617.3\n", 130 | "AA mutation Δ log R Present in strain\n", 131 | "--------------------------------------------\n", 132 | "ORF1a:K680N 0.415 B.1.617.3\n", 133 | "N:A220V 0.0752 B.1.617.2\n", 134 | "S:E484Q 0.0703 B.1.617.3\n", 135 | "S:A222V 0.0677 B.1.617.2\n", 136 | "ORF1a:M585V -0.0624 B.1.617.2\n", 137 | "N:S187L 0.0596 B.1.617.3\n", 138 | "ORF1a:D2980N 0.0498 B.1.617.2\n", 139 | "ORF1a:M3655I 0.0413 B.1.617.3\n", 140 | "ORF1a:P309L 0.0409 B.1.617.2\n", 141 | "ORF1a:S3675- 0.0409 B.1.617.3\n" 142 | ] 143 | } 144 | ], 145 | "source": [ 146 | "print_diff(\"B.1.617.2\", \"B.1.617.3\")" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [] 155 | } 156 | ], 157 | "metadata": { 158 | "kernelspec": { 159 | "display_name": "Python 3", 160 | "language": "python", 161 | "name": "python3" 162 | }, 163 | "language_info": { 164 | "codemirror_mode": { 165 | "name": "ipython", 166 | "version": 3 167 | }, 168 | "file_extension": ".py", 169 | "mimetype": "text/x-python", 170 | "name": "python", 171 | "nbconvert_exporter": "python", 172 | "pygments_lexer": "ipython3", 173 | "version": "3.7.0" 174 | } 175 | }, 176 | "nbformat": 4, 177 | "nbformat_minor": 4 178 | } 179 | -------------------------------------------------------------------------------- /notebooks/explore-nextclade.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import json\n", 10 | "import pandas as 
pd" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "with open(\"results/gisaid.subset.json\", \"r\") as f:\n", 20 | " data = json.load(f)" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "data[0].keys()" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "print(data[0]['seqName'])\n", 39 | "print(data[0]['alignmentScore'])" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "data[0]['aaSubstitutions']" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "data[0]['aaDeletions']" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "df = pd.read_csv(\"results/gisaid.subset.tsv\", sep=\"\\t\")" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "df" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "df.columns" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "df[\"aaSubstitutions\"].tolist()" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": null, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "df[\"aaDeletions\"]" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "df[\"aaDeletions\"][0]" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "len(df)" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "len(df.columns)" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [] 138 | } 139 | ], 140 | "metadata": { 141 | "kernelspec": { 142 | "display_name": "Python 3 (ipykernel)", 143 | "language": "python", 144 | "name": "python3" 145 | }, 146 | "language_info": { 147 | "codemirror_mode": { 148 | "name": "ipython", 149 | "version": 3 150 | }, 151 | "file_extension": ".py", 152 | "mimetype": "text/x-python", 153 | "name": "python", 154 | "nbconvert_exporter": "python", 155 | "pygments_lexer": "ipython3", 156 | "version": "3.9.1" 157 | } 158 | }, 159 | "nbformat": 4, 160 | "nbformat_minor": 4 161 | } 162 | -------------------------------------------------------------------------------- /notebooks/explore-owid-vaccinations.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import datetime\n", 11 | "import pandas as pd\n", 12 | "import torch\n", 13 | "import matplotlib.pyplot as plt" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | 
"outputs": [], 21 | "source": [ 22 | "path = os.path.expanduser(\"~/github/owid/covid-19-data/public/data/vaccinations/vaccinations.csv\")\n", 23 | "df = pd.read_csv(path, header=0)\n", 24 | "df" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "portion = df[\"total_vaccinations_per_hundred\"].to_numpy() / 100\n", 34 | "portion" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "locations = df[\"location\"].to_list()\n", 44 | "location_id = sorted(set(locations))\n", 45 | "location_id = {name: i for i, name in enumerate(location_id)}\n", 46 | "locations = torch.tensor([location_id[n] for n in locations], dtype=torch.long)" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "metadata": {}, 53 | "outputs": [], 54 | "source": [ 55 | "def parse_date(s):\n", 56 | " return datetime.datetime.strptime(s, \"%Y-%m-%d\")\n", 57 | "\n", 58 | "start_date = parse_date(\"2020-12-01\")\n", 59 | "dates = torch.tensor([(parse_date(d) - start_date).days\n", 60 | " for d in df[\"date\"]])\n", 61 | "assert dates.min() >= 0" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "T = int(1 + dates.max())\n", 71 | "R = len(location_id)\n", 72 | "dense_portion = torch" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "plt.scatter(dates, portion, s=10, alpha=0.5);" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [] 90 | } 91 | ], 92 | "metadata": { 93 | "kernelspec": { 94 | "display_name": "Python 3", 95 | "language": "python", 96 | "name": "python3" 97 | }, 98 | "language_info": { 99 | "codemirror_mode": { 100 | "name": "ipython", 101 | "version": 3 102 | }, 103 | "file_extension": ".py", 104 | "mimetype": "text/x-python", 105 | "name": "python", 106 | "nbconvert_exporter": "python", 107 | "pygments_lexer": "ipython3", 108 | "version": "3.6.12" 109 | } 110 | }, 111 | "nbformat": 4, 112 | "nbformat_minor": 4 113 | } 114 | -------------------------------------------------------------------------------- /notebooks/explore_pairs_of_lineages.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Compare pairs of lineages w.r.t. 
mutational profiles and determinants of transmissibility" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import pickle\n", 17 | "import numpy as np\n", 18 | "import matplotlib\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "import torch\n", 21 | "from pyrocov import pangolin\n", 22 | "import pandas as pd\n", 23 | "\n", 24 | "matplotlib.rcParams[\"figure.dpi\"] = 200\n", 25 | "matplotlib.rcParams[\"axes.edgecolor\"] = \"gray\"\n", 26 | "matplotlib.rcParams[\"savefig.bbox\"] = \"tight\"\n", 27 | "matplotlib.rcParams['font.family'] = 'sans-serif'\n", 28 | "matplotlib.rcParams['font.sans-serif'] = ['Arial', 'Avenir', 'DejaVu Sans']" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": {}, 35 | "outputs": [ 36 | { 37 | "name": "stdout", 38 | "output_type": "stream", 39 | "text": [ 40 | "dict_keys(['location_id', 'mutations', 'weekly_clades', 'features', 'lineage_id', 'lineage_id_inv', 'time_shift'])\n" 41 | ] 42 | } 43 | ], 44 | "source": [ 45 | "dataset = torch.load(\"results/mutrans.data.single.None.pt\", map_location=\"cpu\")\n", 46 | "print(dataset.keys())\n", 47 | "locals().update(dataset)" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 3, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "features = dataset['features']\n", 57 | "coefficients = pd.read_csv(\"paper/mutations.tsv\", sep=\"\\t\", index_col=1)\n", 58 | "coefficients = coefficients.loc[dataset['mutations']].copy()\n", 59 | "feature_names = coefficients.index.values.tolist()\n", 60 | "\n", 61 | "lineage_id = {name: i for i, name in enumerate(lineage_id_inv)}\n", 62 | "lineage_id_inv = dataset['lineage_id_inv']\n", 63 | "\n", 64 | "deltaR = coefficients['Δ log R'].values\n", 65 | "zscore = coefficients['mean/stddev'].values" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 4, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "##########################################\n", 75 | "### select pair of lineages to compare ###\n", 76 | "##########################################\n", 77 | "A, B = 'B.1.617.2', 'B.1'\n", 78 | "#A, B = 'B.1.1.7', 'B.1.1'\n", 79 | "#A, B = 'B.1.427', 'B.1'\n", 80 | "#A, B = 'B.1.351', 'B.1'\n", 81 | "#A, B = 'P.1', 'B.1.1'\n", 82 | "#A, B = 'AY.2', 'B.1.617.2'\n", 83 | "\n", 84 | "A_id, B_id = lineage_id[A], lineage_id[B]\n", 85 | "A_feat, B_feat = features[A_id].numpy(), features[B_id].numpy()\n", 86 | "\n", 87 | "delta_cov = A_feat - B_feat\n", 88 | "delta_cov_abs = np.fabs(A_feat - B_feat)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 5, 94 | "metadata": {}, 95 | "outputs": [ 96 | { 97 | "name": "stdout", 98 | "output_type": "stream", 99 | "text": [ 100 | "deltaR_cutoff 0.042476400000000004\n", 101 | "ORF1a:P2287S \t deltaR: 0.044 zscore: 253.71 \t\t delta_feature: 0.80\n", 102 | "ORF1a:T3255I \t deltaR: 0.045 zscore: 240.18 \t\t delta_feature: 0.80\n", 103 | "S:L452R \t deltaR: 0.048 zscore: 244.88 \t\t delta_feature: 0.98\n", 104 | "S:P681R \t deltaR: 0.051 zscore: 300.61 \t\t delta_feature: 0.97\n", 105 | "\n", 106 | " B.1.617.2 over B.1\n", 107 | "ORF1a:P2287S, ORF1a:T3255I, S:L452R, S:P681R, " 108 | ] 109 | } 110 | ], 111 | "source": [ 112 | "# look at top 100 mutations w.r.t. 
effect size\n", 113 | "deltaR_cutoff = np.fabs(deltaR)[np.argsort(np.fabs(deltaR))[-100]]\n", 114 | "print(\"deltaR_cutoff\", deltaR_cutoff)\n", 115 | "\n", 116 | "selected_features = []\n", 117 | "\n", 118 | "for i, name in enumerate(feature_names):\n", 119 | " if len(name) <= 6:\n", 120 | " name = name + \" \"\n", 121 | " dR = deltaR[i]\n", 122 | " dC = delta_cov[i]\n", 123 | " z = zscore[i]\n", 124 | " if dR > deltaR_cutoff and np.fabs(dC) > 0.5:\n", 125 | " selected_features.append(name)\n", 126 | " print(\"{} \\t deltaR: {:.3f} zscore: {:.2f} \\t\\t delta_feature: {:.2f}\".format(name, dR, z, dC))\n", 127 | "\n", 128 | "print(\"\\n\", A, \"over\" ,B)\n", 129 | "for s in selected_features:\n", 130 | " print(s + \", \", end='')" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 6, 136 | "metadata": {}, 137 | "outputs": [ 138 | { 139 | "name": "stdout", 140 | "output_type": "stream", 141 | "text": [ 142 | "M:I82T \t deltaR: 0.038 zscore: 230.30 \t\t delta_feature: 1.00\n", 143 | "N:D63G \t deltaR: 0.028 zscore: 201.31 \t\t delta_feature: 0.99\n", 144 | "ORF1a:P2287S \t deltaR: 0.044 zscore: 253.71 \t\t delta_feature: 0.80\n", 145 | "ORF1a:T3255I \t deltaR: 0.045 zscore: 240.18 \t\t delta_feature: 0.80\n", 146 | "ORF1b:G662S \t deltaR: 0.027 zscore: 228.74 \t\t delta_feature: 0.98\n", 147 | "ORF1b:P1000L \t deltaR: 0.035 zscore: 211.10 \t\t delta_feature: 0.98\n", 148 | "S:D950N \t deltaR: 0.036 zscore: 231.28 \t\t delta_feature: 0.98\n", 149 | "S:E156- \t deltaR: 0.028 zscore: 203.85 \t\t delta_feature: 0.95\n", 150 | "S:L452R \t deltaR: 0.048 zscore: 244.88 \t\t delta_feature: 0.98\n", 151 | "S:P681R \t deltaR: 0.051 zscore: 300.61 \t\t delta_feature: 0.97\n", 152 | "\n", 153 | " B.1.617.2 over B.1\n", 154 | "M:I82T , N:D63G , ORF1a:P2287S, ORF1a:T3255I, ORF1b:G662S, ORF1b:P1000L, S:D950N, S:E156-, S:L452R, S:P681R, " 155 | ] 156 | } 157 | ], 158 | "source": [ 159 | "selected_features = []\n", 160 | "\n", 161 | "# look at large z-score mutations (i.e. 
increase growth rate)\n", 162 | "for i, name in enumerate(feature_names):\n", 163 | " if len(name) <= 6:\n", 164 | " name = name + \" \"\n", 165 | " dR = deltaR[i]\n", 166 | " dC = delta_cov[i]\n", 167 | " z = zscore[i]\n", 168 | " if z > 200.0 and np.fabs(dC) > 0.5:\n", 169 | " selected_features.append(name)\n", 170 | " print(\"{} \\t deltaR: {:.3f} zscore: {:.2f} \\t\\t delta_feature: {:.2f}\".format(name, dR, z, dC))\n", 171 | "\n", 172 | "print(\"\\n\", A, \"over\" ,B)\n", 173 | "for s in selected_features:\n", 174 | " print(s + \", \", end='')" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 7, 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [ 183 | "def findfeat(s):\n", 184 | " for i, n in enumerate(feature_names):\n", 185 | " if n==s:\n", 186 | " return i\n", 187 | " return -1" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 8, 193 | "metadata": {}, 194 | "outputs": [ 195 | { 196 | "name": "stdout", 197 | "output_type": "stream", 198 | "text": [ 199 | "0.0054945056\n", 200 | "0.96954316\n" 201 | ] 202 | } 203 | ], 204 | "source": [ 205 | "print(features[lineage_id['B.1']].numpy()[findfeat('S:H69-')])\n", 206 | "print(features[lineage_id['B.1.1.7']].numpy()[findfeat('S:H69-')])" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": 9, 212 | "metadata": {}, 213 | "outputs": [ 214 | { 215 | "data": { 216 | "text/plain": [ 217 | "array([[1.73514e-02, 2.10334e+02]])" 218 | ] 219 | }, 220 | "execution_count": 9, 221 | "metadata": {}, 222 | "output_type": "execute_result" 223 | } 224 | ], 225 | "source": [ 226 | "coefficients[coefficients.index == 'S:H69-'][['Δ log R', 'mean/stddev']].values" 227 | ] 228 | } 229 | ], 230 | "metadata": { 231 | "kernelspec": { 232 | "display_name": "Python 3", 233 | "language": "python", 234 | "name": "python3" 235 | }, 236 | "language_info": { 237 | "codemirror_mode": { 238 | "name": "ipython", 239 | "version": 3 240 | }, 241 | "file_extension": ".py", 242 | "mimetype": "text/x-python", 243 | "name": "python", 244 | "nbconvert_exporter": "python", 245 | "pygments_lexer": "ipython3", 246 | "version": "3.8.3" 247 | } 248 | }, 249 | "nbformat": 4, 250 | "nbformat_minor": 4 251 | } 252 | -------------------------------------------------------------------------------- /notebooks/explore_pango_lineages.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from pyrocov import pangolin" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "for k, v in sorted(pangolin.PANGOLIN_ALIASES.items()):\n", 19 | " print(f\"{k}\\t{v}\")" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "pangolin.update_aliases()" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [] 37 | } 38 | ], 39 | "metadata": { 40 | "kernelspec": { 41 | "display_name": "Python 3", 42 | "language": "python", 43 | "name": "python3" 44 | }, 45 | "language_info": { 46 | "codemirror_mode": { 47 | "name": "ipython", 48 | "version": 3 49 | }, 50 | "file_extension": ".py", 51 | "mimetype": "text/x-python", 52 | "name": "python", 53 | "nbconvert_exporter": "python", 54 | "pygments_lexer": 
"ipython3", 55 | "version": "3.6.12" 56 | } 57 | }, 58 | "nbformat": 4, 59 | "nbformat_minor": 4 60 | } 61 | -------------------------------------------------------------------------------- /notebooks/grid_search.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Analyzing results of grid search\n", 8 | "\n", 9 | "This notebook assumes you've downloaded data and run a grid search experiment\n", 10 | "```sh\n", 11 | "make update # many hours\n", 12 | "python mutrans.py --grid-search # many hours\n", 13 | "```" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "import math\n", 23 | "import re\n", 24 | "import numpy as np\n", 25 | "import pandas as pd\n", 26 | "import matplotlib\n", 27 | "import matplotlib.pyplot as plt\n", 28 | "from pyrocov.util import pearson_correlation\n", 29 | "from pyrocov.plotting import force_apart\n", 30 | "\n", 31 | "matplotlib.rcParams[\"figure.dpi\"] = 200" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": { 38 | "scrolled": false 39 | }, 40 | "outputs": [], 41 | "source": [ 42 | "df = pd.read_csv(\"results/grid_search.tsv\", sep=\"\\t\")\n", 43 | "df = df.fillna(\"\")\n", 44 | "df" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "df.columns" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": { 60 | "scrolled": false 61 | }, 62 | "outputs": [], 63 | "source": [ 64 | "model_type = df[\"model_type\"].to_list()\n", 65 | "cond_data = df[\"cond_data\"].to_list()\n", 66 | "mutation_corr = df[\"mutation_corr\"].to_numpy()\n", 67 | "mutation_error = df[\"mutation_rmse\"].to_numpy() / df[\"mutation_stddev\"].to_numpy()\n", 68 | "mae_pred = df[\"England B.1.1.7 MAE\"].to_numpy()\n", 69 | "\n", 70 | "loss = df[\"loss\"].to_numpy()\n", 71 | "min_loss, max_loss = loss.min(), loss.max()\n", 72 | "assert (loss > 0).all(), \"you'll need to switch to symlog or sth\"\n", 73 | "loss = np.log(loss)\n", 74 | "loss -= loss.min()\n", 75 | "loss /= loss.max()\n", 76 | "R_alpha = df[\"R(B.1.1.7)/R(A)\"].to_numpy()\n", 77 | "R_delta = df[\"R(B.1.617.2)/R(A)\"].to_numpy()\n", 78 | "\n", 79 | "def plot_concordance(filenames=[], colorby=\"R\"):\n", 80 | " legend = {}\n", 81 | " def abbreviate_param(match):\n", 82 | " k = match.group()[:-1]\n", 83 | " v = k[0].upper()\n", 84 | " legend[v] = k\n", 85 | " return v\n", 86 | " def abbreviate_sample(match):\n", 87 | " k = match.group()[:-1]\n", 88 | " v = k[0]\n", 89 | " legend[v] = k\n", 90 | " return v + \"꞊\"\n", 91 | " fig, axes = plt.subplots(2, figsize=(8, 12))\n", 92 | " for ax, X, Y, xlabel, ylabel in zip(\n", 93 | " axes, [mutation_error, R_alpha], [mae_pred, R_delta],\n", 94 | " [\n", 95 | " # \"Pearson correlation of mutaitons\",\n", 96 | " \"Cross-validation error of mutation coefficients (lower is better)\",\n", 97 | " \"R(α) / R(A)\"],\n", 98 | " [\"England α portion MAE (lower is better)\", \"R(δ) / R(A)\"]\n", 99 | " ):\n", 100 | " ax.scatter(X, Y, 30, loss, lw=0, alpha=0.8, cmap=\"coolwarm\")\n", 101 | " ax.set_xlabel(xlabel)\n", 102 | " ax.set_ylabel(ylabel)\n", 103 | " \n", 104 | " X_, Y_ = force_apart(X, Y, stepsize=2)\n", 105 | " assert X_.dim() == 1\n", 106 | " X_X = []\n", 107 | " Y_Y = []\n", 108 | " for x_, x, y_, y in 
zip(X_, X, Y_, Y):\n", 109 | " X_X.extend([float(x_), float(x), None])\n", 110 | " Y_Y.extend([float(y_), float(y), None])\n", 111 | " ax.plot(X_X, Y_Y, \"k-\", lw=0.5, alpha=0.5, zorder=-10)\n", 112 | " for x, y, mt, cd, l in zip(X_, Y_, model_type, cond_data, loss):\n", 113 | " name = f\"{mt}-{cd}\"\n", 114 | " name = re.sub(\"[a-z_]+-\", abbreviate_param, name)\n", 115 | " name = re.sub(\"[a-z_]+=\", abbreviate_sample, name)\n", 116 | " name = name.replace(\"-\", \"\")\n", 117 | " ax.text(x, y, name, fontsize=7, va=\"center\", alpha=1 - 0.666 * l)\n", 118 | " \n", 119 | " axes[0].set_xscale(\"log\")\n", 120 | " axes[0].set_yscale(\"log\")\n", 121 | " axes[0].plot([], [], \"bo\", markeredgewidth=0, markersize=5, alpha=0.5,\n", 122 | " label=f\"loss={min_loss:0.2g} (better)\")\n", 123 | " axes[0].plot([], [], \"ro\", markeredgewidth=0, markersize=5, alpha=0.5,\n", 124 | " label=f\"loss={max_loss:0.2g} (worse)\")\n", 125 | " for k, v in sorted(legend.items()):\n", 126 | " axes[0].plot([], [], \"wo\", label=f\"{k} = {v}\")\n", 127 | " axes[0].legend(loc=\"upper right\", fontsize=\"small\")\n", 128 | " min_max = [max(X.min(), Y.min()), min(X.max(), Y.max())]\n", 129 | " axes[1].plot(min_max, min_max, \"k--\", alpha=0.2, zorder=-10)\n", 130 | " plt.subplots_adjust(hspace=0.15)\n", 131 | " for filename in filenames:\n", 132 | " plt.savefig(filename)\n", 133 | " \n", 134 | "plot_concordance([\"paper/grid_search.png\"])" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": null, 140 | "metadata": { 141 | "scrolled": false 142 | }, 143 | "outputs": [], 144 | "source": [ 145 | "import torch\n", 146 | "grid = torch.load(\"results/mutrans.grid.pt\")\n", 147 | "\n", 148 | "def plot_mutation_agreements(grid):\n", 149 | " fig, axes = plt.subplots(len(grid), 3, figsize=(8, 1 + 3 * len(grid)))\n", 150 | " for axe, (name, holdouts) in zip(axes, sorted(grid.items())):\n", 151 | " (name0, fit0), (name1, fit1), (name2, fit2) = holdouts.items()\n", 152 | " pairs = [\n", 153 | " [(name0, fit0), (name1, fit1)],\n", 154 | " [(name0, fit0), (name2, fit2)],\n", 155 | " [(name1, fit1), (name2, fit2)],\n", 156 | " ]\n", 157 | " means = [v[\"coef\"] * 0.01 for v in holdouts.values()]\n", 158 | " x0 = min(mean.min().item() for mean in means)\n", 159 | " x1 = max(mean.max().item() for mean in means)\n", 160 | " lb = 1.05 * x0 - 0.05 * x1\n", 161 | " ub = 1.05 * x1 - 0.05 * x0\n", 162 | " axe[1].set_title(str(name))\n", 163 | " axe[0].set_ylabel(str(name).replace(\"-\", \"\\n\").replace(\",\", \"\\n\"), fontsize=8)\n", 164 | " for ax, ((name1, fit1), (name2, fit2)) in zip(axe, pairs):\n", 165 | " mutations = sorted(set(fit1[\"mutations\"]) & set(fit2[\"mutations\"]))\n", 166 | " means = []\n", 167 | " for fit in (fit1, fit2):\n", 168 | " m_to_i = {m: i for i, m in enumerate(fit[\"mutations\"])}\n", 169 | " idx = torch.tensor([m_to_i[m] for m in mutations])\n", 170 | " means.append(fit[\"coef\"])\n", 171 | " ax.plot([lb, ub], [lb, ub], 'k--', alpha=0.3, zorder=-100)\n", 172 | " ax.scatter(means[1].numpy(), means[0].numpy(), 30, alpha=0.3, lw=0, color=\"darkred\")\n", 173 | " ax.axis(\"equal\")\n", 174 | " ax.set_title(\"ρ = {:0.2g}\".format(pearson_correlation(means[0], means[1])))\n", 175 | "plot_mutation_agreements(grid)" 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "metadata": {}, 181 | "source": [ 182 | "## Debugging plotting code" 183 | ] 184 | }, 185 | { 186 | "cell_type": "code", 187 | "execution_count": null, 188 | "metadata": { 189 | "scrolled": false 190 | }, 191 | 
"outputs": [], 192 | "source": [ 193 | "from pyrocov.plotting import force_apart\n", 194 | "torch.manual_seed(1234567890)\n", 195 | "X, Y = torch.randn(2, 200)\n", 196 | "X_, Y_ = force_apart(X, Y)\n", 197 | "plt.plot(X, Y, \"ko\")\n", 198 | "for i in range(8):\n", 199 | " plt.plot(X_ + i / 20, Y_, \"r.\")\n", 200 | "X_X = []\n", 201 | "Y_Y = []\n", 202 | "for x_, x, y_, y in zip(X_, X, Y_, Y):\n", 203 | " X_X.extend([float(x_), float(x), None])\n", 204 | " Y_Y.extend([float(y_), float(y), None])\n", 205 | "plt.plot(X_X, Y_Y, \"k-\", lw=0.5, alpha=0.5, zorder=-10);" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [] 214 | } 215 | ], 216 | "metadata": { 217 | "kernelspec": { 218 | "display_name": "Python 3", 219 | "language": "python", 220 | "name": "python3" 221 | }, 222 | "language_info": { 223 | "codemirror_mode": { 224 | "name": "ipython", 225 | "version": 3 226 | }, 227 | "file_extension": ".py", 228 | "mimetype": "text/x-python", 229 | "name": "python", 230 | "nbconvert_exporter": "python", 231 | "pygments_lexer": "ipython3", 232 | "version": "3.6.12" 233 | } 234 | }, 235 | "nbformat": 4, 236 | "nbformat_minor": 4 237 | } 238 | -------------------------------------------------------------------------------- /notebooks/mutrans_loo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Predictive accuracy of mutrans model on new lineages\n", 8 | "\n", 9 | "This notebook assumes you have run\n", 10 | "```sh\n", 11 | "make update\n", 12 | "make preprocess\n", 13 | "python mutrans.py --vary-leaves\n", 14 | "```" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import pickle\n", 24 | "import numpy as np\n", 25 | "import matplotlib\n", 26 | "import matplotlib.pyplot as plt\n", 27 | "import torch\n", 28 | "from pyrocov import pangolin\n", 29 | "from pyrocov.util import pearson_correlation, quotient_central_moments\n", 30 | "\n", 31 | "matplotlib.rcParams[\"figure.dpi\"] = 200\n", 32 | "matplotlib.rcParams[\"axes.edgecolor\"] = \"gray\"\n", 33 | "matplotlib.rcParams[\"figure.facecolor\"] = \"white\"\n", 34 | "matplotlib.rcParams[\"savefig.bbox\"] = \"tight\"\n", 35 | "matplotlib.rcParams['font.family'] = 'sans-serif'\n", 36 | "matplotlib.rcParams['font.sans-serif'] = ['Arial', 'Avenir', 'DejaVu Sans']" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "dataset = torch.load(\"results/mutrans.data.single.3000.1.50.None.pt\", map_location=\"cpu\")\n", 46 | "print(dataset.keys())\n", 47 | "locals().update(dataset)" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "lineage_id = dataset[\"lineage_id\"]\n", 57 | "clade_id_to_lineage_id = dataset[\"clade_id_to_lineage_id\"]" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "loo = torch.load(\"results/mutrans.vary_leaves.pt\", map_location=\"cpu\")\n", 67 | "print(len(loo))\n", 68 | "print(list(loo)[-1])\n", 69 | "print(list(loo.values())[-1].keys())" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": {}, 76 | "outputs": [], 
77 | "source": [ 78 | "print(len(mutations))\n", 79 | "print(list(loo.values())[0][\"median\"][\"rate_loc\"].shape)" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "best_rate_loc = None\n", 89 | "loo_rate_loc = {}\n", 90 | "for k, v in loo.items():\n", 91 | " rate = quotient_central_moments(v[\"median\"][\"rate_loc\"], clade_id_to_lineage_id)[1]\n", 92 | " holdout = k[-1]\n", 93 | " if holdout:\n", 94 | " key = holdout[-1][-1][-1][-1].replace(\"$\", \"\").replace(\"^\", \"\")\n", 95 | " loo_rate_loc[key] = rate\n", 96 | " else:\n", 97 | " best_rate_loc = rate\n", 98 | "print(\" \".join(loo_rate_loc))" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "def plot_prediction(filenames=[], debug=False, use_who=True):\n", 108 | " X1, Y1, X2, Y2, labels, debug_labels = [], [], [], [], [], []\n", 109 | " who = {vs[0]: k for k, vs in pangolin.WHO_ALIASES.items()}\n", 110 | " ancestors = set(lineage_id)\n", 111 | " for child, rate_loc in loo_rate_loc.items():\n", 112 | " parent = pangolin.compress(\n", 113 | " pangolin.get_most_recent_ancestor(\n", 114 | " pangolin.decompress(child), ancestors\n", 115 | " )\n", 116 | " )\n", 117 | " c = lineage_id[child]\n", 118 | " p = lineage_id[parent]\n", 119 | " truth = best_rate_loc[c].item()\n", 120 | " baseline = rate_loc[p].item()\n", 121 | " guess = rate_loc[c].item()\n", 122 | " Y1.append(truth)\n", 123 | " X1.append(guess)\n", 124 | " Y2.append(truth - baseline)\n", 125 | " X2.append(guess - baseline)\n", 126 | " labels.append(who.get(child))\n", 127 | " debug_labels.append(child)\n", 128 | " mae = np.abs(np.array(Y2)).mean()\n", 129 | " print(f\"MAE(baseline - full estimate) = {mae:0.4g}\")\n", 130 | " fig, axes = plt.subplots(1, 2, figsize=(7, 3.5))\n", 131 | " for ax, X, Y in zip(axes, [X1, X2], [Y1, Y2]):\n", 132 | " X = np.array(X)\n", 133 | " Y = np.array(Y)\n", 134 | " ax.scatter(X, Y, 40, lw=0, alpha=1, color=\"white\", zorder=-5)\n", 135 | " ax.scatter(X, Y, 20, lw=0, alpha=0.3, color=\"darkred\")\n", 136 | " lb = min(min(X), min(Y))\n", 137 | " ub = max(max(X), max(Y))\n", 138 | " d = ub - lb\n", 139 | " lb -= 0.03 * d\n", 140 | " ub += 0.05 * d\n", 141 | " ax.plot([lb, ub], [lb, ub], \"k--\", alpha=0.2, zorder=-10)\n", 142 | " ax.set_xlim(lb, ub)\n", 143 | " ax.set_ylim(lb, ub)\n", 144 | " rho = pearson_correlation(X, Y)\n", 145 | " mae = np.abs(X - Y).mean()\n", 146 | " ax.text(0.3 * lb + 0.7 * ub, 0.8 * lb + 0.2 * ub,\n", 147 | " #f\" ρ = {rho:0.3f}\\nMAE = {mae:0.3g}\",\n", 148 | " f\" ρ = {rho:0.3f}\",\n", 149 | " backgroundcolor=\"white\", ha=\"center\", va=\"center\")\n", 150 | " for x, y, label, debug_label in zip(X, Y, labels, debug_labels):\n", 151 | " pad = 0.012\n", 152 | " if label is not None:\n", 153 | " ax.plot([x], [y], \"ko\", mfc=\"#c77\", c=\"black\", ms=4, mew=0.5)\n", 154 | " ax.text(x, y + pad, label if use_who else debug_label,\n", 155 | " va=\"bottom\", ha=\"center\", fontsize=6)\n", 156 | " elif abs(x - y) > 0.2:\n", 157 | " ax.plot([x], [y], \"ko\", mfc=\"#c77\", c=\"black\", ms=4, mew=0.5)\n", 158 | " ax.text(x, y + pad, debug_label, va=\"bottom\", ha=\"center\", fontsize=6)\n", 159 | " axes[0].set_ylabel(\"full estimate\")\n", 160 | " axes[0].set_xlabel(\"LOO estimate\")\n", 161 | " axes[1].set_ylabel(\"full estimate − baseline\")\n", 162 | " axes[1].set_xlabel(\"LOO estimate − baseline\")\n", 163 | " plt.tight_layout()\n", 
164 | " for f in filenames:\n", 165 | " plt.savefig(f)\n", 166 | "plot_prediction(debug=True)\n", 167 | "plot_prediction(use_who=False, filenames=[\"paper/lineage_prediction.png\"])" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [] 176 | } 177 | ], 178 | "metadata": { 179 | "kernelspec": { 180 | "display_name": "Python 3", 181 | "language": "python", 182 | "name": "python3" 183 | }, 184 | "language_info": { 185 | "codemirror_mode": { 186 | "name": "ipython", 187 | "version": 3 188 | }, 189 | "file_extension": ".py", 190 | "mimetype": "text/x-python", 191 | "name": "python", 192 | "nbconvert_exporter": "python", 193 | "pygments_lexer": "ipython3", 194 | "version": "3.6.12" 195 | } 196 | }, 197 | "nbformat": 4, 198 | "nbformat_minor": 4 199 | } 200 | -------------------------------------------------------------------------------- /notebooks/paper: -------------------------------------------------------------------------------- 1 | ../paper -------------------------------------------------------------------------------- /notebooks/preprocess-pairwise-allele-fractions.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "96af8117", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "id": "7e9fa3b9", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "max_num_clades = 3000\n", 21 | "min_num_mutations = 1" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 3, 27 | "id": "eb793068", 28 | "metadata": {}, 29 | "outputs": [ 30 | { 31 | "name": "stdout", 32 | "output_type": "stream", 33 | "text": [ 34 | "clade_id \tdict of size 3000\n", 35 | "clade_id_inv \tlist of size 3000\n", 36 | "clade_id_to_lineage_id \tTensor of shape (3000,)\n", 37 | "clade_to_lineage \tdict of size 3000\n", 38 | "features \tTensor of shape (3000, 2904)\n", 39 | "lineage_id \tdict of size 1544\n", 40 | "lineage_id_inv \tlist of size 1544\n", 41 | "lineage_id_to_clade_id \tTensor of shape (1544,)\n", 42 | "lineage_to_clade \tdict of size 1544\n", 43 | "location_id \tOrderedDict of size 1560\n", 44 | "location_id_inv \tlist of size 1560\n", 45 | "mutations \tlist of size 2904\n", 46 | "pc_index \tTensor of shape (185001,)\n", 47 | "sparse_counts \tdict of size 3\n", 48 | "state_to_country \tTensor of shape (1355,)\n", 49 | "time \tTensor of shape (56,)\n", 50 | "weekly_clades \tTensor of shape (56, 1560, 3000)\n", 51 | "CPU times: user 15 ms, sys: 891 ms, total: 906 ms\n", 52 | "Wall time: 1.46 s\n" 53 | ] 54 | } 55 | ], 56 | "source": [ 57 | "%%time\n", 58 | "def load_data():\n", 59 | " filename = f\"results/mutrans.data.single.{max_num_clades}.{min_num_mutations}.50.None.pt\"\n", 60 | " dataset = torch.load(filename, map_location=\"cpu\")\n", 61 | " return dataset\n", 62 | "dataset = load_data()\n", 63 | "locals().update(dataset)\n", 64 | "for k, v in sorted(dataset.items()):\n", 65 | " if isinstance(v, torch.Tensor):\n", 66 | " print(f\"{k} \\t{type(v).__name__} of shape {tuple(v.shape)}\")\n", 67 | " else:\n", 68 | " print(f\"{k} \\t{type(v).__name__} of size {len(v)}\")" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "id": "ffd0f8ed", 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "# TODO(martin) Use features, mutations, 
weekly_clades" 79 | ] 80 | } 81 | ], 82 | "metadata": { 83 | "kernelspec": { 84 | "display_name": "Python 3 (ipykernel)", 85 | "language": "python", 86 | "name": "python3" 87 | }, 88 | "language_info": { 89 | "codemirror_mode": { 90 | "name": "ipython", 91 | "version": 3 92 | }, 93 | "file_extension": ".py", 94 | "mimetype": "text/x-python", 95 | "name": "python", 96 | "nbconvert_exporter": "python", 97 | "pygments_lexer": "ipython3", 98 | "version": "3.9.1" 99 | } 100 | }, 101 | "nbformat": 4, 102 | "nbformat_minor": 5 103 | } 104 | -------------------------------------------------------------------------------- /notebooks/results: -------------------------------------------------------------------------------- 1 | ../results -------------------------------------------------------------------------------- /paper/.gitignore: -------------------------------------------------------------------------------- 1 | *.aux 2 | *.bbl 3 | *.blg 4 | *.pdf 5 | *.out 6 | -------------------------------------------------------------------------------- /paper/BA.1_logistic_regression_all_US.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/BA.1_logistic_regression_all_US.jpg -------------------------------------------------------------------------------- /paper/Makefile: -------------------------------------------------------------------------------- 1 | all: supplement.pdf 2 | 3 | push: 4 | cp supplement.pdf ../data/ 5 | 6 | supplement.pdf: supplement.tex FORCE 7 | pdflatex supplement 8 | bibtex supplement 9 | pdflatex supplement 10 | 11 | FORCE: 12 | -------------------------------------------------------------------------------- /paper/N_S_peak_ratio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/N_S_peak_ratio.png -------------------------------------------------------------------------------- /paper/N_S_ratio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/N_S_ratio.png -------------------------------------------------------------------------------- /paper/README.md: -------------------------------------------------------------------------------- 1 | # Images and data for publication 2 | 3 | This directory contains figures and tables output by the [PyR0 4 | model](https://www.medrxiv.org/content/10.1101/2021.09.07.21263228v1). These 5 | outputs are aggregated to weeks, PANGO lineages, and amino acid changes. 6 | 7 | Figures and tables are generated by first running preprocessing and inference, 8 | then postprocessing with the following Jupyter notebooks: 9 | [ `mutrans.ipynb` ](../mutrans.ipynb), 10 | [ `mutrans_gene.ipynb` ](../mutrans_gene.ipynb), 11 | [ `mutrans_prediction.ipynb` ](../mutrans_prediction.ipynb), 12 | [ `mutrans_backtesting.ipynb` ](../mutrans_backtesting.ipynb). 13 | 14 | ## Data tables 15 | 16 | - [Mutation table](mutations.tsv) is ranked by statistical significance. 17 | The "mean" field denotes the estimated effect on log growth rate of each mutation. 18 | - [Lineage table](strains.tsv) is ranked by growth rate. 
19 | 20 | ## Manhattan plots 21 | 22 | ![Manhattan plot of entire genome](manhattan.png) 23 | ![Manhattan plot of N gene](manhattan_N.png) 24 | ![Manhattan plot of S gene](manhattan_S.png) 25 | ![Manhattan plot of ORF1a gene](manhattan_ORF1a.png) 26 | ![Manhattan plot of ORF1b gene](manhattan_ORF1b.png) 27 | ![Manhattan plot of ORF3a gene](manhattan_ORF3a.png) 28 | 29 | ## Information density plots 30 | 31 | ![How informative is each gene?](vary_gene_likelihood.png) 32 | ![How informative is each NSP?](vary_nsp_likelihood.png) 33 | 34 | ## Volcano plot 35 | 36 | ![Volcano plot of mutations](volcano.png) 37 | 38 | ## Strain characterization plots 39 | 40 | ![Growth rate versus emergence date](strain_emergence.png) 41 | ![Growth rate versus case count](strain_prevalence.png) 42 | ![PANGO lineage heterogeneity](lineage_heterogeneity.png) 43 | ![Forecast](forecast.png) 44 | ![Deep scanning](deep_scanning.png) 45 | 46 | ## Cross-validation plots 47 | 48 | The following plots assess robustness via 2-fold cross-validation, splitting the data into Europe versus the rest of the world. 49 | 50 | ![Lineage correlation](lineage_agreement.png) 51 | ![Mutation correlation](mutation_agreement.png) 52 | ![Lineage box plots](strain_europe_boxplot.png) 53 | ![Mutation box plots](mutation_europe_boxplot_rankby_s.png) 54 | ![Mutation box plots](mutation_europe_boxplot_rankby_t.png) 55 | ![Lineage prediction](lineage_prediction.png) 56 | 57 | ## Data plots 58 | 59 | ![Distribution of samples among regions](region_distribution.png) 60 | ![Distribution of samples among clades](clade_distribution.png) 61 | 62 | ## Acknowledgements 63 | 64 | The aggregated model outputs in this directory were generated from data inputs 65 | including GISAID records (https://gisaid.org), an UShER tree placement 66 | of those records 67 | (http://hgdownload.soe.ucsc.edu/goldenPath/wuhCor1/UShER_SARS-CoV-2), PANGO 68 | lineage classifications (https://cov-lineages.org), and case count time series 69 | from Johns Hopkins University (https://github.com/CSSEGISandData/COVID-19). 70 | Results in this directory can alternatively be generated using GenBank records 71 | (https://www.ncbi.nlm.nih.gov) instead of GISAID records. 72 | 73 | We gratefully acknowledge all data contributors, i.e. the Authors and their Originating laboratories responsible for obtaining the specimens, and their Submitting laboratories for generating the genetic sequence and metadata and sharing via the GISAID initiative [1,2] on which this research is based. A total of 6,466,299 submissions are included in this study. A complete list of the 6,466,299 accession numbers is available in [accession_ids.txt.xz](accession_ids.txt.xz). 74 | 75 | 1. GISAID Initiative and global contributors, 76 | EpiCoV(TM) human coronavirus 2019 database. 77 | GISAID (2020), (available at https://gisaid.org). 78 | 2. S. Elbe, G. Buckland-Merrett, 79 | Data, disease and diplomacy: GISAID's innovative contribution to global health. 80 | Glob Chall. 1, 33-46 (2017). 81 | 3. National Center for Biotechnology Information (NCBI) [Internet]. 82 | Bethesda (MD): National Library of Medicine (US), 83 | National Center for Biotechnology Information; 84 | [1988] – [cited 2017 Apr 06].
85 | Available from: https://www.ncbi.nlm.nih.gov 86 | -------------------------------------------------------------------------------- /paper/accession_ids.txt.xz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/accession_ids.txt.xz -------------------------------------------------------------------------------- /paper/backtesting/L1_error_barplot_England.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/L1_error_barplot_England.png -------------------------------------------------------------------------------- /paper/backtesting/L1_error_barplot_all.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/L1_error_barplot_all.png -------------------------------------------------------------------------------- /paper/backtesting/L1_error_barplot_other.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/L1_error_barplot_other.png -------------------------------------------------------------------------------- /paper/backtesting/L1_error_barplot_top100-1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/L1_error_barplot_top100-1000.png -------------------------------------------------------------------------------- /paper/backtesting/L1_error_barplot_top100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/L1_error_barplot_top100.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_248_early_prediction_england.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_248_early_prediction_england.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_346_early_prediction_england.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_346_early_prediction_england.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_360_early_prediction_england.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_360_early_prediction_england.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_430.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_430.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_444.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_444.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_458.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_458.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_472.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_472.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_486.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_486.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_500.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_500.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_514.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_514.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_514_early_prediction_england.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_514_early_prediction_england.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_528.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_528.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_528_early_prediction_england.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_528_early_prediction_england.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_542.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_542.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_556.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_556.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_570.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_570.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_584.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_584.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_598.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_598.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_612.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_612.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_612_early_prediction_england.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_612_early_prediction_england.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_626.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_626.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_640.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_640.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_640_early_prediction_england.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_640_early_prediction_england.png -------------------------------------------------------------------------------- 
/paper/backtesting/backtesting_day_654.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_654.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_668.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_668.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_682.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_682.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_696.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_696.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_710.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_710.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_710_early_prediction_england.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_710_early_prediction_england.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_724.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_724.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_724_early_prediction_england.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_724_early_prediction_england.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_738.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_738.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_738_early_prediction_england.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_738_early_prediction_england.png 
-------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_752.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_752.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_752_early_prediction_england.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_752_early_prediction_england.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_766.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_766.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_766_early_prediction_Asia_Europe_Africa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_766_early_prediction_Asia_Europe_Africa.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_766_early_prediction_Massachusetts_for_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_766_early_prediction_Massachusetts_for_heatmap.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_766_early_prediction_Massachusetts_for_heatmap_nolegend.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_766_early_prediction_Massachusetts_for_heatmap_nolegend.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_766_early_prediction_USA_France_England_Brazil_Australia_Russia.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_766_early_prediction_USA_France_England_Brazil_Australia_Russia.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_766_early_prediction_brazil_for_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_766_early_prediction_brazil_for_heatmap.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_766_early_prediction_brazil_for_heatmap_nolegend.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_766_early_prediction_brazil_for_heatmap_nolegend.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_766_early_prediction_england.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_766_early_prediction_england.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_766_early_prediction_england_for_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_766_early_prediction_england_for_heatmap.png -------------------------------------------------------------------------------- /paper/backtesting/backtesting_day_766_early_prediction_england_for_heatmap_nolegend.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_766_early_prediction_england_for_heatmap_nolegend.png -------------------------------------------------------------------------------- /paper/binding_retained.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/binding_retained.png -------------------------------------------------------------------------------- /paper/clade_distribution.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/clade_distribution.png -------------------------------------------------------------------------------- /paper/coef_scale_manhattan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/coef_scale_manhattan.png -------------------------------------------------------------------------------- /paper/coef_scale_pearson.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/coef_scale_pearson.png -------------------------------------------------------------------------------- /paper/coef_scale_scatter.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/coef_scale_scatter.png -------------------------------------------------------------------------------- /paper/coef_scale_volcano.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/coef_scale_volcano.png -------------------------------------------------------------------------------- /paper/convergence.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/convergence.png -------------------------------------------------------------------------------- /paper/convergence_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/convergence_loss.png -------------------------------------------------------------------------------- /paper/convergent_evolution.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/convergent_evolution.jpg -------------------------------------------------------------------------------- /paper/deep_scanning.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/deep_scanning.png -------------------------------------------------------------------------------- /paper/delta_logR_breakdown.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/delta_logR_breakdown.png -------------------------------------------------------------------------------- /paper/dobanno_2021_immunogenicity_N.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/dobanno_2021_immunogenicity_N.png -------------------------------------------------------------------------------- /paper/forecast.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/forecast.png -------------------------------------------------------------------------------- /paper/formatted/Figure1.ai: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/formatted/Figure1.ai -------------------------------------------------------------------------------- /paper/formatted/Figure1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/formatted/Figure1.jpg -------------------------------------------------------------------------------- /paper/formatted/Figure2.ai: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/formatted/Figure2.ai -------------------------------------------------------------------------------- /paper/formatted/Figure2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/formatted/Figure2.jpg -------------------------------------------------------------------------------- /paper/formatted/Figure2_alt.ai: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/formatted/Figure2_alt.ai -------------------------------------------------------------------------------- /paper/formatted/Figure2_alt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/formatted/Figure2_alt.jpg -------------------------------------------------------------------------------- /paper/gene_ratios.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/gene_ratios.png -------------------------------------------------------------------------------- /paper/ind_emergences_violin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/ind_emergences_violin.png -------------------------------------------------------------------------------- /paper/lineage_agreement.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/lineage_agreement.png -------------------------------------------------------------------------------- /paper/lineage_heterogeneity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/lineage_heterogeneity.png -------------------------------------------------------------------------------- /paper/lineage_prediction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/lineage_prediction.png -------------------------------------------------------------------------------- /paper/main.bib: -------------------------------------------------------------------------------- 1 | @article{bingham2019pyro, 2 | title={Pyro: Deep universal probabilistic programming}, 3 | author={Bingham, Eli and Chen, Jonathan P and Jankowiak, Martin and Obermeyer, Fritz and Pradhan, Neeraj and Karaletsos, Theofanis and Singh, Rohit and Szerlip, Paul and Horsfall, Paul and Goodman, Noah D}, 4 | journal={The Journal of Machine Learning Research}, 5 | volume={20}, 6 | number={1}, 7 | pages={973--978}, 8 | year={2019}, 9 | publisher={JMLR. 
org}, 10 | } 11 | 12 | @misc{aksamentov2020nextclade, 13 | title={Nextclade}, 14 | author={Aksamentov, Ivan and Neher, Richard}, 15 | url={https://github.com/nextstrain/nextclade}, 16 | } 17 | 18 | @inproceedings{gorinova2020automatic, 19 | title={Automatic reparameterisation of probabilistic programs}, 20 | author={Gorinova, Maria and Moore, Dave and Hoffman, Matthew}, 21 | booktitle={International Conference on Machine Learning}, 22 | pages={3648--3657}, 23 | year={2020}, 24 | organization={PMLR}, 25 | } 26 | 27 | @inproceedings{paszke2017automatic, 28 | title={Automatic differentiation in PyTorch}, 29 | author={Paszke, Adam and Gross, Sam and Chintala, Soumith and Chanan, Gregory and Yang, Edward and DeVito, Zachary and Lin, Zeming and Desmaison, Alban and Antiga, Luca and Lerer, Adam}, 30 | booktitle={NIPS-W}, 31 | year={2017}, 32 | } 33 | 34 | @article{elbe2017gisaid, 35 | author={Elbe, S. and Buckland-Merrett, G.}, 36 | year={2017}, 37 | title={Data, disease and diplomacy: GISAID’s innovative contribution to global health}, 38 | journal={Global Challenges}, 39 | volume={1}, 40 | pages={33--46}, 41 | doi={10.1002/gch2.1018}, 42 | pmcid={31565258}, 43 | } 44 | 45 | @article{rambaut2020dynamic, 46 | title={A dynamic nomenclature proposal for SARS-CoV-2 lineages to assist genomic epidemiology}, 47 | author={Rambaut, A. and Holmes, E.C. and O’Toole, Á and Hill, V. and McCrone, J. T. and Ruis, C. and du Plessis, L. and Pybus, O. G.}, 48 | year={2020}, 49 | journal={Nature Microbiology}, 50 | doi={10.1038/s41564-020-0770-5}, 51 | } 52 | 53 | @article{neal2003slice, 54 | author={Neal, Radford M.}, 55 | year={2003}, 56 | title={Slice Sampling}, 57 | journal={Annals of Statistics}, 58 | volume={31}, 59 | number={3}, 60 | pages={705--767}, 61 | } 62 | -------------------------------------------------------------------------------- /paper/main.md: -------------------------------------------------------------------------------- 1 | # Analysis of 1,000,000 virus genomes reveals determinants of relative growth rate 2 | 3 | Fritz Obermeyer, Jacob Lemieux, Stephen Schaffner, Daniel Park 4 | 5 | ## Abstract 6 | 7 | We fit a large logistic growth regression model to sequenced genomes of SARS-CoV-2 samples distributed in space (globally) and time (late 2019 -- early 2021). 8 | This model attributes relative changes in transmissibility to each observed mutation, aggregated over >800 PANGO lineages. 9 | Results show that one region of the N gene is particularly important in determining transmissibility, followed by the S gene and the ORF3a gene. 10 | 11 | ## Discussion 12 | 13 | Note that many of the most transmissibility-influencing mutations occur in the N gene and particularly at positions 28800--29000 (see Figure TODO). 14 | 15 | ## Materials and methods 16 | 17 | ### Model 18 | 19 | We model the strain distribution over time as a softmax of a set of linear growth functions, one per strain; this is the multivariate generalization of the standard logistic growth model. 20 | Further, we factorize the growth rate of each strain as a sum over growth rates of the mutations in that strain: 21 | ``` 22 | strain_portion = softmax(strain_init + (mutation_rate @ mutations_per_strain) * time) 23 | ``` 24 | Our hierarchical probabilistic model includes latent variables for: 25 | - global feature regularization parameter (how sparse are the features?). 26 | - initial strain prevalence in each (region, strain). 27 | - rate coefficients of each mutation, with a Laplace prior.
28 | - global observation concentration (how noisy are the observations?). 29 | 30 | The model uses a Dirichlet-multinomial likelihood (the multivariate 31 | generalization of the beta-binomial) with a learned, shared concentration parameter. 32 | 33 | TODO add graphical model figure 34 | 35 | This model is robust to a number of sources of bias: 36 | - Sampling bias across regions (it's fine for one region to sample 100x more than another) 37 | - Sampling bias over time (it's fine to change sampling rate over time) 38 | - Change in absolute growth rate of all strains, in any (region, time) cell (i.e. the model is robust to changes in local policies or weather, as long as those changes equally affect all strains). 39 | 40 | However, the model is susceptible to the following sources of bias: 41 | - Biased sampling in any (region, time) cell (e.g. sequencing only in case of S-gene target failure). 42 | - Changes in sampling bias within a single region over time (e.g. a country has a lab in only one city, then spins up a second lab in another distant city with different strain portions). 43 | 44 | TODO We considered the possibility of biased submission to the GISAID database, and analyzed the CDC NS3 dataset, finding similar results. 45 | 46 | ### Inference 47 | 48 | We use a combination of MAP estimation and stochastic variational inference. 49 | We MAP estimate the global parameters and initial strain prevalences. 50 | For the remaining mutation-transmissibility latent variables we would like to estimate statistical significance, so we fit a mean-field variational distribution: independent normal distributions with learned means and variances. 51 | While the mean-field assumption leads to underdispersion on an absolute scale, it nevertheless allows us to rank the _relative statistical significance_ of different mutations. 52 | This relative ranking is sufficient for the purposes of this paper: to focus attention on mutations and regions of the SARS-CoV-2 genome that appear to be linked to increased transmissibility. 53 | 54 | Inference is performed via the Pyro probabilistic programming language. 55 | Source code is available at [github.com/broadinstitute/pyro-cov](https://github.com/broadinstitute/pyro-cov). 56 | 57 | ## References 58 | 59 | - Eli Bingham, Jonathan P. Chen, Martin Jankowiak, Fritz Obermeyer, Neeraj Pradhan, 60 | Theofanis Karaletsos, Rohit Singh, Paul Szerlip, Paul Horsfall, Noah D. Goodman 61 | (2018) 62 | "Pyro: Deep Universal Probabilistic Programming" 63 | https://arxiv.org/abs/1810.09538 64 | 65 | ## Tables 66 | 67 | 1. [Top Mutations](top_mutations.md) 68 | 2.
[Top Strains](top_strains.md) 69 | 70 | ## Figures 71 | 72 | [![manhattan](manhattan.png)](manhattan.pdf) 73 | 74 | [![volcano](volcano.png)](volcano.pdf) 75 | -------------------------------------------------------------------------------- /paper/manhattan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/manhattan.png -------------------------------------------------------------------------------- /paper/manhattan_N.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/manhattan_N.png -------------------------------------------------------------------------------- /paper/manhattan_N_coef_scale_0.05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/manhattan_N_coef_scale_0.05.png -------------------------------------------------------------------------------- /paper/manhattan_ORF1a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/manhattan_ORF1a.png -------------------------------------------------------------------------------- /paper/manhattan_ORF1a_coef_scale_0.05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/manhattan_ORF1a_coef_scale_0.05.png -------------------------------------------------------------------------------- /paper/manhattan_ORF1b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/manhattan_ORF1b.png -------------------------------------------------------------------------------- /paper/manhattan_ORF1b_coef_scale_0.05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/manhattan_ORF1b_coef_scale_0.05.png -------------------------------------------------------------------------------- /paper/manhattan_ORF3a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/manhattan_ORF3a.png -------------------------------------------------------------------------------- /paper/manhattan_ORF3a_coef_scale_0.05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/manhattan_ORF3a_coef_scale_0.05.png -------------------------------------------------------------------------------- /paper/manhattan_S.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/manhattan_S.png -------------------------------------------------------------------------------- /paper/manhattan_S_coef_scale_0.05.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/manhattan_S_coef_scale_0.05.png -------------------------------------------------------------------------------- /paper/moran.csv: -------------------------------------------------------------------------------- 1 | ,NumMutations,GeneSize,PValue,Lengthscale 2 | EntireGenome,2904.0,29394.0,1e-06,100.0 3 | EntireGenome,2904.0,29394.0,1e-06,500.0 4 | S,415.0,3786.0,0.00191,50.0 5 | N,220.0,1251.0,0.017627,50.0 6 | ORF7a,75.0,360.0,0.024066,18.0 7 | ORF3a,198.0,789.0,0.024307,39.45 8 | ORF1a,1107.0,13182.0,0.02971,50.0 9 | ORF7b,26.0,126.0,0.089589,6.3 10 | ORF14,69.0,213.0,0.112527,10.65 11 | ORF6,19.0,177.0,0.138634,8.85 12 | ORF1b,552.0,8052.0,0.329416,50.0 13 | E,17.0,195.0,0.455606,9.75 14 | M,42.0,639.0,0.518497,31.95 15 | ORF8,91.0,357.0,0.810603,17.85 16 | ORF10,18.0,102.0,0.864965,5.1 17 | ORF9b,55.0,240.0,0.927562,12.0 18 | -------------------------------------------------------------------------------- /paper/multinomial_LR_vs_pyr0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/multinomial_LR_vs_pyr0.jpg -------------------------------------------------------------------------------- /paper/mutation_agreement.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_agreement.png -------------------------------------------------------------------------------- /paper/mutation_europe_boxplot_rankby_s.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_europe_boxplot_rankby_s.png -------------------------------------------------------------------------------- /paper/mutation_europe_boxplot_rankby_t.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_europe_boxplot_rankby_t.png -------------------------------------------------------------------------------- /paper/mutation_scoring/by_gene_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/by_gene_heatmap.png -------------------------------------------------------------------------------- /paper/mutation_scoring/plot_individual_mutation_significance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_individual_mutation_significance.png -------------------------------------------------------------------------------- /paper/mutation_scoring/plot_individual_mutation_significance_M.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_individual_mutation_significance_M.png 
-------------------------------------------------------------------------------- /paper/mutation_scoring/plot_individual_mutation_significance_S.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_individual_mutation_significance_S.png -------------------------------------------------------------------------------- /paper/mutation_scoring/plot_individual_mutation_significance_S__K_to_N.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_individual_mutation_significance_S__K_to_N.png -------------------------------------------------------------------------------- /paper/mutation_scoring/plot_k_to_n_mutation_significance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_k_to_n_mutation_significance.png -------------------------------------------------------------------------------- /paper/mutation_scoring/plot_per_gene_aggregate_sign.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_per_gene_aggregate_sign.png -------------------------------------------------------------------------------- /paper/mutation_scoring/plot_pval_vs_top_genes_E.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_pval_vs_top_genes_E.png -------------------------------------------------------------------------------- /paper/mutation_scoring/plot_pval_vs_top_genes_N.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_pval_vs_top_genes_N.png -------------------------------------------------------------------------------- /paper/mutation_scoring/plot_pval_vs_top_genes_ORF10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_pval_vs_top_genes_ORF10.png -------------------------------------------------------------------------------- /paper/mutation_scoring/plot_pval_vs_top_genes_ORF1a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_pval_vs_top_genes_ORF1a.png -------------------------------------------------------------------------------- /paper/mutation_scoring/plot_pval_vs_top_genes_ORF1b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_pval_vs_top_genes_ORF1b.png -------------------------------------------------------------------------------- 
/paper/mutation_scoring/plot_pval_vs_top_genes_ORF3a.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_pval_vs_top_genes_ORF3a.png
--------------------------------------------------------------------------------
/paper/mutation_scoring/plot_pval_vs_top_genes_ORF6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_pval_vs_top_genes_ORF6.png
--------------------------------------------------------------------------------
/paper/mutation_scoring/plot_pval_vs_top_genes_ORF7b.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_pval_vs_top_genes_ORF7b.png
--------------------------------------------------------------------------------
/paper/mutation_scoring/plot_pval_vs_top_genes_ORF8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_pval_vs_top_genes_ORF8.png
--------------------------------------------------------------------------------
/paper/mutation_scoring/plot_pval_vs_top_genes_ORF9b.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_pval_vs_top_genes_ORF9b.png
--------------------------------------------------------------------------------
/paper/mutation_scoring/plot_pval_vs_top_genes_S.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_pval_vs_top_genes_S.png
--------------------------------------------------------------------------------
/paper/mutation_scoring/pvals_vs_top_genes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/pvals_vs_top_genes.png
--------------------------------------------------------------------------------
/paper/mutation_scoring/top_substitutions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/top_substitutions.png
--------------------------------------------------------------------------------
/paper/mutation_summaries.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_summaries.jpg
--------------------------------------------------------------------------------
/paper/region_distribution.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/region_distribution.png
--------------------------------------------------------------------------------
/paper/relative_growth_rate.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/relative_growth_rate.jpg
--------------------------------------------------------------------------------
/paper/schematic_overview.key:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/schematic_overview.key
--------------------------------------------------------------------------------
/paper/schematic_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/schematic_overview.png
--------------------------------------------------------------------------------
/paper/scibib.bib:
--------------------------------------------------------------------------------
% scibib.bib

% This is the .bib file used to compile the document "A simple Science
% template" (scifile.tex). It is not intended as an example of how to
% set up your BibTeX file.

@misc{tth, note = "The package is TTH, available at
http://hutchinson.belmont.ma.us/tth/ ."}

@misc{use2e, note = "As the mark-up of the \TeX\ source for this
document makes clear, your file should be coded in \LaTeX
2${\varepsilon}$, not \LaTeX\ 2.09 or an earlier release. Also,
please use the \texttt{article} document class."}

@misc{inclme, note="Among whom are the author of this document. The
``real'' references and notes contained herein were compiled using
B{\small{IB}}\TeX\ from the sample .bib file \texttt{scibib.bib}, the style
package \texttt{scicite.sty}, and the bibliography style file
\texttt{Science.bst}."}

@misc{nattex, note="One of the equation editors we use, Equation Magic
(MicroPress Inc., Forest Hills, NY; http://www.micropress-inc.com/),
interprets native \TeX\ source code and generates an equation as an
OLE picture object that can then be cut and pasted directly into Word.
This editor, however, does not handle \LaTeX\ environments (such as
\texttt{\{array\}} or \texttt{\{eqnarray\}}); it can interpret only
\TeX\ codes. Thus, when there's a choice, we ask that you avoid these
\LaTeX\ calls in displayed math --- for example, that you use the
\TeX\ \verb+\matrix+ command for ordinary matrices, rather than the
\LaTeX\ \texttt{\{array\}} environment."}
--------------------------------------------------------------------------------
/paper/spectrum_transmissibility.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/spectrum_transmissibility.jpg
--------------------------------------------------------------------------------
/paper/strain_emergence.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/strain_emergence.png
--------------------------------------------------------------------------------
/paper/strain_europe_boxplot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/strain_europe_boxplot.png
--------------------------------------------------------------------------------
/paper/strain_prevalence.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/strain_prevalence.png
--------------------------------------------------------------------------------
/paper/supplement_table_3_02_05_2022.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/supplement_table_3_02_05_2022.jpeg
--------------------------------------------------------------------------------
/paper/table2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/table2.jpeg
--------------------------------------------------------------------------------
/paper/table_1.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/table_1.jpeg
--------------------------------------------------------------------------------
/paper/table_1.tsv:
--------------------------------------------------------------------------------
	Rank	Gene	Substitution	"Fold Increase in Fitness"	"Number of Lineages"
0	1	S	R346K	1.098	29
1	2	S	L452Q	1.092	9
2	3	S	V213G	1.057	158
3	4	S	S704L	1.080	9
4	5	S	T19I	1.053	158
5	6	ORF1b	R1315C	1.047	159
6	7	ORF1b	T2163I	1.056	159
7	8	ORF3a	T223I	1.052	178
8	9	ORF1a	T3090I	1.042	159
9	10	M	D3N	1.127	76
10	11	ORF10	L37F	1.067	22
11	12	N	S413R	1.049	159
12	13	ORF9b	D16G	1.091	36
13	14	S	T376A	1.040	158
14	15	ORF1a	L3027F	1.043	159
15	16	ORF1a	T842I	1.036	169
16	17	S	D796Y	1.063	174
17	18	S	L452R	1.177	367
18	19	S	S371F	1.039	158
19	20	S	Q954H	1.066	211
--------------------------------------------------------------------------------
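table_1.tsv stores Table 1 with an unnamed leading index column. A hedged loading sketch (the column names below assume the single-line headers shown above; in the raw file the two long headers contain quoted line breaks):

import pandas as pd

df = pd.read_csv("paper/table_1.tsv", sep="\t", index_col=0)
# Top-ranked spike substitutions by estimated fold increase in fitness.
spike = df[df["Gene"] == "S"].sort_values("Fold Increase in Fitness", ascending=False)
print(spike.head())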
/paper/table_1_02_05_2022.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/table_1_02_05_2022.jpeg
--------------------------------------------------------------------------------
/paper/table_2_02_05_2022.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/table_2_02_05_2022.jpeg
--------------------------------------------------------------------------------
/paper/vary_gene_elbo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/vary_gene_elbo.png
--------------------------------------------------------------------------------
/paper/vary_gene_likelihood.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/vary_gene_likelihood.png
--------------------------------------------------------------------------------
/paper/vary_gene_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/vary_gene_loss.png
--------------------------------------------------------------------------------
/paper/vary_nsp_elbo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/vary_nsp_elbo.png
--------------------------------------------------------------------------------
/paper/vary_nsp_likelihood.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/vary_nsp_likelihood.png
--------------------------------------------------------------------------------
/paper/volcano.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/volcano.png
--------------------------------------------------------------------------------
/pyrocov/__init__.py:
--------------------------------------------------------------------------------
# Copyright Contributors to the Pyro-Cov project.
# SPDX-License-Identifier: Apache-2.0

__version__ = "0.1.0"
--------------------------------------------------------------------------------
/pyrocov/aa.py:
--------------------------------------------------------------------------------
# Copyright Contributors to the Pyro-Cov project.
# SPDX-License-Identifier: Apache-2.0

DNA_TO_AA = {
    "TTT": "F", "TTC": "F", "TTA": "L", "TTG": "L",
    "CTT": "L", "CTC": "L", "CTA": "L", "CTG": "L",
    "ATT": "I", "ATC": "I", "ATA": "I", "ATG": "M",
    "GTT": "V", "GTC": "V", "GTA": "V", "GTG": "V",
    "TCT": "S", "TCC": "S", "TCA": "S", "TCG": "S",
    "CCT": "P", "CCC": "P", "CCA": "P", "CCG": "P",
    "ACT": "T", "ACC": "T", "ACA": "T", "ACG": "T",
    "GCT": "A", "GCC": "A", "GCA": "A", "GCG": "A",
    "TAT": "Y", "TAC": "Y", "TAA": None, "TAG": None,  # TAA, TAG: stop
    "CAT": "H", "CAC": "H", "CAA": "Q", "CAG": "Q",
    "AAT": "N", "AAC": "N", "AAA": "K", "AAG": "K",
    "GAT": "D", "GAC": "D", "GAA": "E", "GAG": "E",
    "TGT": "C", "TGC": "C", "TGA": None, "TGG": "W",  # TGA: stop
    "CGT": "R", "CGC": "R", "CGA": "R", "CGG": "R",
    "AGT": "S", "AGC": "S", "AGA": "R", "AGG": "R",
    "GGT": "G", "GGC": "G", "GGA": "G", "GGG": "G",
}
--------------------------------------------------------------------------------
/pyrocov/align.py:
--------------------------------------------------------------------------------
# Copyright Contributors to the Pyro-Cov project.
# SPDX-License-Identifier: Apache-2.0

import os

# Source: https://samtools.github.io/hts-specs/SAMv1.pdf
CIGAR_CODES = "MIDNSHP=X"  # Note minimap2 uses only "MIDNSH"

ROOT = os.path.dirname(os.path.dirname(__file__))
NEXTCLADE_DATA = os.path.expanduser("~/github/nextstrain/nextclade/data/sars-cov-2")
PANGOLEARN_DATA = os.path.expanduser("~/github/cov-lineages/pangoLEARN/pangoLEARN/data")
--------------------------------------------------------------------------------
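DNA_TO_AA maps each of the 64 codons to a one-letter amino acid code, with None marking the three stop codons. A small illustrative sketch (the translate helper is hypothetical, not part of the module):

from pyrocov.aa import DNA_TO_AA

def translate(orf: str) -> str:
    """Translate an in-frame DNA sequence, stopping at the first stop codon."""
    aas = []
    for i in range(0, len(orf) - 2, 3):
        aa = DNA_TO_AA[orf[i : i + 3]]
        if aa is None:  # stop codon
            break
        aas.append(aa)
    return "".join(aas)

assert translate("ATGGGTTAA") == "MG"  # Met, Gly, then TAA stop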
/pyrocov/distributions.py:
--------------------------------------------------------------------------------
# Copyright Contributors to the Pyro-Cov project.
# SPDX-License-Identifier: Apache-2.0

import math

import torch
from pyro.distributions import TorchDistribution
from torch.distributions import constraints
from torch.distributions.utils import broadcast_all


class SoftLaplace(TorchDistribution):
    """
    Smooth distribution with Laplace-like tail behavior.
    """

    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
    support = constraints.real
    has_rsample = True

    def __init__(self, loc, scale, *, validate_args=None):
        self.loc, self.scale = broadcast_all(loc, scale)
        super().__init__(self.loc.shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(SoftLaplace, _instance)
        batch_shape = torch.Size(batch_shape)
        new.loc = self.loc.expand(batch_shape)
        new.scale = self.scale.expand(batch_shape)
        super(SoftLaplace, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        z = (value - self.loc) / self.scale
        return math.log(2 / math.pi) - self.scale.log() - torch.logaddexp(z, -z)

    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        u = self.loc.new_empty(shape).uniform_()
        return self.icdf(u)

    def cdf(self, value):
        if self._validate_args:
            self._validate_sample(value)
        z = (value - self.loc) / self.scale
        return z.exp().atan().mul(2 / math.pi)

    def icdf(self, value):
        return value.mul(math.pi / 2).tan().log().mul(self.scale).add(self.loc)
--------------------------------------------------------------------------------
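SoftLaplace samples by inverse CDF, so rsample is fully reparameterized, and the identity icdf(cdf(x)) = log(tan(atan(exp(z)))) * scale + loc = x makes a cheap self-check. A sketch assuming pyro and torch are installed:

import torch

from pyrocov.distributions import SoftLaplace

d = SoftLaplace(torch.tensor(0.0), torch.tensor(2.0))
x = torch.linspace(-4.0, 4.0, 9)
assert torch.allclose(d.icdf(d.cdf(x)), x, atol=1e-5)  # cdf/icdf round-trip
samples = d.rsample(torch.Size([1000]))  # reparameterized sampling via icdf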
/pyrocov/external/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/pyrocov/external/__init__.py
--------------------------------------------------------------------------------
/pyrocov/external/usher/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Yatish Turakhia (https://turakhia.eng.ucsd.edu/), Haussler Lab (https://hausslergenomics.ucsc.edu/), Corbett-Detig Lab (https://corbett.ucsc.edu/)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/pyrocov/external/usher/README.md:
--------------------------------------------------------------------------------
This directory contains code copied from https://github.com/yatisht/usher
--------------------------------------------------------------------------------
/pyrocov/external/usher/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/pyrocov/external/usher/__init__.py
--------------------------------------------------------------------------------
/pyrocov/fasta.py:
--------------------------------------------------------------------------------
# Copyright Contributors to the Pyro-Cov project.
# SPDX-License-Identifier: Apache-2.0


class ShardedFastaWriter:
    """
    Writer that splits into multiple fasta files to avoid nextclade file size
    limit.
    """

    def __init__(self, filepattern, max_count=5000):
        assert filepattern.count("*") == 1
        self.filepattern = filepattern
        self.max_count = max_count
        self._file_count = 0
        self._line_count = 0
        self._file = None

    def _open(self):
        filename = self.filepattern.replace("*", str(self._file_count))
        print(f"writing to {filename}")
        return open(filename, "wt")

    def __enter__(self):
        assert self._file is None
        self._file = self._open()
        self._file_count += 1
        return self

    def __exit__(self, *args, **kwargs):
        self._file.close()
        self._file = None
        self._file_count = 0
        self._line_count = 0

    def write(self, name, sequence):
        if self._line_count == self.max_count:
            self._file.close()
            self._file = self._open()
            self._file_count += 1
            self._line_count = 0
        self._file.write(">")
        self._file.write(name)
        self._file.write("\n")
        self._file.write(sequence)
        self._file.write("\n")
        self._line_count += 1
--------------------------------------------------------------------------------
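ShardedFastaWriter starts a new shard each time max_count records have been written, so no single FASTA exceeds nextclade's input limits. A usage sketch with illustrative filenames (the results/ directory is assumed to exist):

from pyrocov.fasta import ShardedFastaWriter

with ShardedFastaWriter("results/sequences.*.fasta", max_count=2) as writer:
    writer.write("seq1", "ACGT")
    writer.write("seq2", "ACGG")
    writer.write("seq3", "ACGC")  # rolls over to results/sequences.1.fasta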
/pyrocov/hashsubset.py:
--------------------------------------------------------------------------------
# Copyright Contributors to the Pyro-Cov project.
# SPDX-License-Identifier: Apache-2.0

import hashlib
import heapq


class RandomSubDict:
    def __init__(self, max_size):
        assert isinstance(max_size, int) and max_size > 0
        self.key_to_value = {}
        self.hash_to_key = {}
        self.heap = []
        self.max_size = max_size

    def __len__(self):
        return len(self.heap)

    def __setitem__(self, key, value):
        assert key not in self.key_to_value

        # Add (key,value) pair.
        self.key_to_value[key] = value
        h = hashlib.sha1(key.encode("utf-8")).hexdigest()
        self.hash_to_key[h] = key
        heapq.heappush(self.heap, h)

        # Truncate via k-min-hash.
        if len(self.heap) > self.max_size:
            h = heapq.heappop(self.heap)
            key = self.hash_to_key.pop(h)
            self.key_to_value.pop(key)

    def keys(self):
        assert len(self.key_to_value) <= self.max_size
        return self.key_to_value.keys()

    def values(self):
        assert len(self.key_to_value) <= self.max_size
        return self.key_to_value.values()

    def items(self):
        assert len(self.key_to_value) <= self.max_size
        return self.key_to_value.items()
--------------------------------------------------------------------------------
/pyrocov/ops.py:
--------------------------------------------------------------------------------
# Copyright Contributors to the Pyro-Cov project.
# SPDX-License-Identifier: Apache-2.0

import weakref
from typing import Dict

import torch


def logistic_logsumexp(alpha, beta, delta, tau, *, backend="sequential"):
    """
    Computes::

        (alpha + beta * (delta + tau[:, None])).logsumexp(-1)

    where::

        alpha.shape == [P, S]
        beta.shape == [P, S]
        delta.shape == [P, S]
        tau.shape == [T, P]

    :param str backend: One of "naive", "sequential".
    """
    assert alpha.dim() == 2
    assert alpha.shape == beta.shape == delta.shape
    assert tau.dim() == 2
    assert tau.size(1) == alpha.size(0)
    assert not tau.requires_grad

    if backend == "naive":
        return (alpha + beta * (delta + tau[:, :, None])).logsumexp(-1)
    if backend == "sequential":
        return LogisticLogsumexp.apply(alpha, beta, delta, tau)
    raise ValueError(f"Unknown backend: {repr(backend)}")


class LogisticLogsumexp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, alpha, beta, delta, tau):
        P, S = alpha.shape
        T, P = tau.shape
        output = alpha.new_zeros(T, P)
        for t in range(len(tau)):
            logits = (delta + tau[t, :, None]).mul_(beta).add_(alpha)  # [P, S]
            output[t] = logits.logsumexp(-1)  # [P]

        ctx.save_for_backward(alpha, beta, delta, tau, output)
        return output  # [T, P]

    @staticmethod
    def backward(ctx, grad_output):
        alpha, beta, delta, tau, output = ctx.saved_tensors

        grad_alpha = torch.zeros_like(alpha)  # [P, S]
        grad_beta = torch.zeros_like(beta)  # [P, S]
        for t in range(len(tau)):
            delta_tau = delta + tau[t, :, None]  # [P, S]
            logits = (delta_tau * beta).add_(alpha)  # [P, S]
            softmax_logits = logits.sub_(output[t, :, None]).exp_()  # [P, S]
            grad_logits = softmax_logits * grad_output[t, :, None]  # [P, S]
            grad_alpha += grad_logits  # [P, S]
            grad_beta += delta_tau * grad_logits  # [P, S]

        grad_delta = beta * grad_alpha  # [P, S]
        return grad_alpha, grad_beta, grad_delta, None


_log_factorial_cache: Dict[int, torch.Tensor] = {}


def log_factorial_sum(x: torch.Tensor) -> torch.Tensor:
    if x.requires_grad:
        return (x + 1).lgamma().sum()
    key = id(x)
    if key not in _log_factorial_cache:
        weakref.finalize(x, _log_factorial_cache.pop, key, None)  # type: ignore
        _log_factorial_cache[key] = (x + 1).lgamma().sum()
    return _log_factorial_cache[key]


def sparse_poisson_likelihood(full_log_rate, nonzero_log_rate, nonzero_value):
    """
    The following are equivalent::

        # Version 1. dense
        log_prob = Poisson(log_rate.exp()).log_prob(value).sum()

        # Version 2. sparse
        nnz = value.nonzero(as_tuple=True)
        log_prob = sparse_poisson_likelihood(
            log_rate.logsumexp(-1),
            log_rate[nnz],
            value[nnz],
        )
    """
    # Let p = Poisson(log_rate.exp()). Then
    # p.log_prob(value)
    #   = log_rate * value - log_rate.exp() - (value + 1).lgamma()
    # p.log_prob(0) = -log_rate.exp()
    # p.log_prob(value) - p.log_prob(0)
    #   = log_rate * value - log_rate.exp() - (value + 1).lgamma() + log_rate.exp()
    #   = log_rate * value - (value + 1).lgamma()
    return (
        torch.dot(nonzero_log_rate, nonzero_value)
        - log_factorial_sum(nonzero_value)
        - full_log_rate.exp().sum()
    )


def sparse_multinomial_likelihood(total_count, nonzero_logits, nonzero_value):
    """
    The following are equivalent::

        # Version 1. dense
        log_prob = Multinomial(logits=logits).log_prob(value).sum()

        # Version 2. sparse
        nnz = value.nonzero(as_tuple=True)
        log_prob = sparse_multinomial_likelihood(
            value.sum(-1),
            (logits - logits.logsumexp(-1))[nnz],
            value[nnz],
        )
    """
    return (
        log_factorial_sum(total_count)
        - log_factorial_sum(nonzero_value)
        + torch.dot(nonzero_logits, nonzero_value)
    )


def sparse_categorical_kl(log_q, p_support, log_p):
    """
    Computes the restricted KL divergence::

        sum_i restrict(q)(i) (log q(i) - log p(i))

    where ``p`` is a uniform prior, ``q`` is the posterior, and
    ``restrict(q)`` is the posterior restricted to the support of ``p`` and
    renormalized. Note for degenerate ``p=delta(i)`` this reduces to the log
    likelihood ``log q(i)``.
    """
    assert log_q.dim() == 1
    assert log_p.dim() == 1
    assert p_support.shape == log_p.shape + log_q.shape
    q = log_q.exp()
    sum_q = torch.mv(p_support, q)
    sum_q_log_q = torch.mv(p_support, q * log_q)
    sum_r_log_q = sum_q_log_q / sum_q  # restrict and normalize
    kl = sum_r_log_q - log_p  # note sum_r_log_p = log_p because p is uniform
    return kl.sum()
--------------------------------------------------------------------------------
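The "sequential" backend of logistic_logsumexp trades memory for time: instead of materializing the full [T, P, S] tensor, it loops over time steps and recomputes logits in backward. A small sketch checking that the two backends agree (shapes are illustrative):

import torch

from pyrocov.ops import logistic_logsumexp

T, P, S = 5, 4, 3
alpha, beta, delta = (torch.randn(P, S) for _ in range(3))
tau = torch.randn(T, P)

expected = logistic_logsumexp(alpha, beta, delta, tau, backend="naive")
actual = logistic_logsumexp(alpha, beta, delta, tau, backend="sequential")
assert torch.allclose(actual, expected, atol=1e-5)  # both are [T, P]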
/pyrocov/plotting.py:
--------------------------------------------------------------------------------
# Copyright Contributors to the Pyro-Cov project.
# SPDX-License-Identifier: Apache-2.0

import torch


@torch.no_grad()
def force_apart(*X, radius=[0.05, 0.005], iters=10, stepsize=2, xshift=0.01):
    X = torch.stack([torch.as_tensor(x) for x in X], dim=-1)
    assert len(X.shape) == 2
    radius = torch.as_tensor(radius)
    scale = X.max(0).values - X.min(0).values
    X /= scale
    for _ in range(iters):
        XX = X - X[:, None]
        r = (XX / radius).square().sum(-1, True)
        kernel = r.neg().exp()
        F = (XX * radius.square().sum() / radius**2 * kernel).sum(0)
        F_norm = F.square().sum(-1, True).sqrt().clamp(min=1e-20)
        F *= F_norm.neg().expm1().neg() / F_norm
        F *= stepsize
        if xshift is not None:
            F[:, 0].clamp_(min=0)
        X += F / iters
    if xshift is not None:
        X[:, 0] += xshift
    X *= scale
    return X.unbind(-1)
--------------------------------------------------------------------------------
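force_apart treats the input coordinates as particles with Gaussian repulsion so that overlapping scatter labels separate; with the default xshift, points are only ever pushed rightwards along x. A usage sketch with illustrative points:

import torch

from pyrocov.plotting import force_apart

x = torch.tensor([0.10, 0.11, 0.12, 0.50])
y = torch.tensor([1.0, 1.0, 1.0, 2.0])
x2, y2 = force_apart(x, y)  # the three clustered points spread out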
/pyrocov/sarscov2.py:
--------------------------------------------------------------------------------
# Copyright Contributors to the Pyro-Cov project.
# SPDX-License-Identifier: Apache-2.0

import os
import re
from collections import OrderedDict, defaultdict
from typing import Dict, List, Tuple

from .aa import DNA_TO_AA
from .align import NEXTCLADE_DATA

REFERENCE_SEQ = None  # loaded lazily

# Adapted from https://github.com/nextstrain/ncov/blob/50ceffa/defaults/annotation.gff
# Note these are 1-based positions
annotation_tsv = """\
seqname	source	feature	start	end	score	strand	frame	attribute
.	.	gene	26245	26472	.	+	.	gene_name "E"
.	.	gene	26523	27191	.	+	.	gene_name "M"
.	.	gene	28274	29533	.	+	.	gene_name "N"
.	.	gene	29558	29674	.	+	.	gene_name "ORF10"
.	.	gene	28734	28955	.	+	.	gene_name "ORF14"
.	.	gene	266	13468	.	+	.	gene_name "ORF1a"
.	.	gene	13468	21555	.	+	.	gene_name "ORF1b"
.	.	gene	25393	26220	.	+	.	gene_name "ORF3a"
.	.	gene	27202	27387	.	+	.	gene_name "ORF6"
.	.	gene	27394	27759	.	+	.	gene_name "ORF7a"
.	.	gene	27756	27887	.	+	.	gene_name "ORF7b"
.	.	gene	27894	28259	.	+	.	gene_name "ORF8"
.	.	gene	28284	28577	.	+	.	gene_name "ORF9b"
.	.	gene	21563	25384	.	+	.	gene_name "S"
"""


def _():
    genes = []
    rows = annotation_tsv.split("\n")
    header, rows = rows[0].split("\t"), rows[1:]
    for row in rows:
        if row:
            row = dict(zip(header, row.split("\t")))
            gene_name = row["attribute"].split('"')[1]
            start = int(row["start"])
            end = int(row["end"])
            genes.append(((start, end), gene_name))
    genes.sort()
    return OrderedDict((gene_name, pos) for pos, gene_name in genes)


# This maps gene name to the nucleotide position in the genome,
# as measured in the original Wuhan virus.
GENE_TO_POSITION: Dict[str, Tuple[int, int]] = _()

# This maps gene name to a set of regions in that gene.
# These regions may be used in plotting e.g. mutrans.ipynb.
# Each region has a string label and an extent (start, end)
# measured in amino acid positions relative to the start.
GENE_STRUCTURE: Dict[str, Dict[str, Tuple[int, int]]] = {
    # Source: https://www.nature.com/articles/s41401-020-0485-4/figures/2
    "S": {
        "NTD": (13, 305),
        "RBD": (319, 541),
        "FC": (682, 685),
        "FP": (788, 806),
        "HR1": (912, 984),
        "HR2": (1163, 1213),
        "TM": (1213, 1237),
        "CT": (1237, 1273),
    },
    # Source https://www.nature.com/articles/s41467-021-21953-3
    "N": {
        "NTD": (1, 49),
        "RNA binding": (50, 174),
        "SR": (175, 215),
        "dimerization": (246, 365),
        "CTD": (365, 419),
        # "immunogenic": (133, 217),
    },
    # Source: https://www.ncbi.nlm.nih.gov/protein/YP_009725295.1
    "ORF1a": {
        "nsp1": (0, 180),  # leader protein
        "nsp2": (180, 818),
        "nsp3": (818, 2763),
        "nsp4": (2763, 3263),
        "nsp5": (3263, 3569),  # 3C-like proteinase
        "nsp6": (3569, 3859),
        "nsp7": (3859, 3942),
        "nsp8": (3942, 4140),
        "nsp9": (4140, 4253),
        "nsp10": (4253, 4392),
        "nsp11": (4392, 4405),
    },
    # Source: https://www.ncbi.nlm.nih.gov/protein/1796318597
    "ORF1ab": {
        "nsp1": (0, 180),  # leader protein
        "nsp2": (180, 818),
        "nsp3": (818, 2763),
        "nsp4": (2763, 3263),
        "nsp5": (3263, 3569),  # 3C-like proteinase
        "nsp6": (3569, 3859),
        "nsp7": (3859, 3942),
        "nsp8": (3942, 4140),
        "nsp9": (4140, 4253),
        "nsp10": (4253, 4392),
        "nsp12": (4392, 5324),  # RNA-dependent RNA polymerase
        "nsp13": (5324, 5925),  # helicase
        "nsp14": (5925, 6452),  # 3'-to-5' exonuclease
        "nsp15": (6452, 6798),  # endoRNAse
        "nsp16": (6798, 7096),  # 2'-O-ribose methyltransferase
    },
    # Source: see infer_ORF1b_structure() below.
    "ORF1b": {
        "nsp12": (0, 924),  # RNA-dependent RNA polymerase
        "nsp13": (924, 1525),  # helicase
        "nsp14": (1525, 2052),  # 3'-to-5' exonuclease
        "nsp15": (2052, 2398),  # endoRNAse
        "nsp16": (2398, 2696),  # 2'-O-ribose methyltransferase
    },
}


def load_gene_structure(filename=None, gene_name=None):
    """
    Loads structure from a GenPept file for use in GENE_STRUCTURE.
    This is used only when updating the static GENE_STRUCTURE dict.
    """
    from Bio import SeqIO

    if filename is None:
        assert gene_name is not None
        filename = f"data/{gene_name}.gp"
    result = {}
    with open(filename) as f:
        for record in SeqIO.parse(f, format="genbank"):
            for feature in record.features:
                if feature.type == "mat_peptide":
                    product = feature.qualifiers["product"][0]
                    assert isinstance(product, str)
                    start = int(feature.location.start)
                    end = int(feature.location.end)
                    result[product] = start, end
    return result


def infer_ORF1b_structure():
    """
    Infers approximate ORF1b structure from ORF1ab.
    This is used only when updating the static GENE_STRUCTURE dict.
    """
    ORF1a_start = GENE_TO_POSITION["ORF1a"][0]
    ORF1b_start = GENE_TO_POSITION["ORF1b"][0]
    shift = (ORF1b_start - ORF1a_start) // 3
    result = {}
    for name, (start, end) in GENE_STRUCTURE["ORF1ab"].items():
        start -= shift
        end -= shift
        if end > 0:
            start = max(start, 0)
            result[name] = start, end
    return result


def aa_mutation_to_position(m: str) -> int:
    """
    E.g. map 'S:N501Y' to 21563 + (501 - 1) * 3 = 23063.
    """
    gene_name, subs = m.split(":")
    start, end = GENE_TO_POSITION[gene_name]
    match = re.search(r"\d+", subs)
    assert match is not None
    aa_offset = int(match.group(0)) - 1
    return start + aa_offset * 3


def nuc_mutations_to_aa_mutations(ms: List[str]) -> List[str]:
    global REFERENCE_SEQ
    if REFERENCE_SEQ is None:
        REFERENCE_SEQ = load_reference_sequence()

    ms_by_aa = defaultdict(list)

    for m in ms:
        # Parse a nucleotide mutation such as "A1234G" -> (1234, "G").
        # Note this uses 1-based indexing.
        if isinstance(m, str):
            position_nuc = int(m[1:-1])
            new_nuc = m[-1]
        else:
            # assert isinstance(m, pyrocov.usher.Mutation)
            position_nuc = m.position
            new_nuc = m.mut

        # Find the first matching gene.
        for gene, (start, end) in GENE_TO_POSITION.items():
            if start <= position_nuc <= end:
                position_aa = (position_nuc - start) // 3
                position_codon = (position_nuc - start) % 3
                ms_by_aa[gene, position_aa].append((position_codon, new_nuc))

    # Format cumulative amino acid changes.
    result = []
    for (gene, position_aa), ms in ms_by_aa.items():
        start, end = GENE_TO_POSITION[gene]

        # Apply mutation to determine new aa.
        pos = start + position_aa * 3
        pos -= 1  # convert from 1-based to 0-based
        old_codon = REFERENCE_SEQ[pos : pos + 3]
        new_codon = list(old_codon)
        for position_codon, new_nuc in ms:
            new_codon[position_codon] = new_nuc
        new_codon = "".join(new_codon)

        # Format.
        old_aa = DNA_TO_AA[old_codon]
        new_aa = DNA_TO_AA[new_codon]
        if new_aa == old_aa:  # ignore synonymous substitutions
            continue
        if old_aa is None:
            old_aa = "STOP"
        if new_aa is None:
            new_aa = "STOP"
        result.append(f"{gene}:{old_aa}{position_aa + 1}{new_aa}")  # 1-based
    return result


def load_reference_sequence():
    with open(os.path.join(NEXTCLADE_DATA, "reference.fasta")) as f:
        ref = "".join(line.strip() for line in f if not line.startswith(">"))
    assert len(ref) == 29903, len(ref)
    return ref
--------------------------------------------------------------------------------
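aa_mutation_to_position converts a 1-based amino acid substitution to a 1-based nucleotide coordinate via start + 3 * (position - 1), which the docstring's S:N501Y example pins down:

from pyrocov.sarscov2 import GENE_TO_POSITION, aa_mutation_to_position

start, _ = GENE_TO_POSITION["S"]  # the S gene starts at nucleotide 21563
assert aa_mutation_to_position("S:N501Y") == start + (501 - 1) * 3  # == 23063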
/pyrocov/sketch.cpp:
--------------------------------------------------------------------------------
// Copyright Contributors to the Pyro-Cov project.
// SPDX-License-Identifier: Apache-2.0

// Note: the angle-bracketed template arguments and header names below were
// lost in extraction and are reconstructed from how the code uses them.
#include <cassert>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

#include <torch/extension.h>

inline uint64_t murmur64(uint64_t h) {
    h ^= h >> 33;
    h *= 0xff51afd7ed558ccd;
    h ^= h >> 33;
    h *= 0xc4ceb9fe1a85ec53;
    h ^= h >> 33;
    return h;
}

at::Tensor get_32mers(const std::string& seq) {
    static std::vector<int> to_bits(256, false);
    to_bits['A'] = 0;
    to_bits['C'] = 1;
    to_bits['G'] = 2;
    to_bits['T'] = 3;

    int size = seq.size() - 32 + 1;
    if (size <= 0) {
        return at::empty(0, at::kLong);
    }
    at::Tensor out = at::empty(size, at::kLong);
    int64_t * const out_data = static_cast<int64_t*>(out.data_ptr());

    int64_t kmer = 0;
    for (int pos = 0; pos < 31; ++pos) {
        kmer <<= 2;
        kmer ^= to_bits[seq[pos]];
    }
    for (int pos = 0; pos < size; ++pos) {
        kmer <<= 2;
        kmer ^= to_bits[seq[pos + 31]];
        out_data[pos] = kmer;
    }
    return out;
}

struct KmerCounter {
    KmerCounter() : to_bits(256, false), counts() {
        to_bits['A'] = 0;
        to_bits['C'] = 1;
        to_bits['G'] = 2;
        to_bits['T'] = 3;
    }

    void update(const std::string& seq) {
        int size = seq.size() - 32 + 1;
        if (size <= 0) {
            return;
        }

        int64_t kmer = 0;
        for (int pos = 0; pos < 31; ++pos) {
            kmer <<= 2;
            kmer ^= to_bits[seq[pos]];
        }
        for (int pos = 0; pos < size; ++pos) {
            kmer <<= 2;
            kmer ^= to_bits[seq[pos + 31]];
            counts[kmer] += 1;
        }
    }

    void truncate_below(int64_t threshold) {
        for (auto i = counts.begin(); i != counts.end();) {
            if (i->second < threshold) {
                counts.erase(i++);
            } else {
                ++i;
            }
        }
    }

    std::unordered_map<int64_t, int64_t> to_dict() const { return counts; }

    std::vector<int> to_bits;
    std::unordered_map<int64_t, int64_t> counts;
};

void string_to_soft_hash(int min_k, int max_k, const std::string& seq, at::Tensor out) {
    static std::vector<int> to_bits(256, false);
    to_bits['A'] = 0;
    to_bits['C'] = 1;
    to_bits['G'] = 2;
    to_bits['T'] = 3;

    static std::vector<uint64_t> salts(33);
    assert(max_k < salts.size());
    for (int k = min_k; k <= max_k; ++k) {
        salts[k] = murmur64(1 + k);
    }

    float * const data_begin = static_cast<float*>(out.data_ptr());
    float * const data_end = data_begin + out.size(-1);
    for (int pos = 0, end = seq.size(); pos != end; ++pos) {
        if (max_k > end - pos) {
            max_k = end - pos;
        }
        uint64_t hash = 0;
        for (int k = 1; k <= max_k; ++k) {
            int i = k - 1;
            hash ^= to_bits[seq[pos + i]] << (i + i);
            if (k < min_k) continue;
            uint64_t hash_k = murmur64(salts[k] ^ hash);
            for (float *p = data_begin; p != data_end; ++p) {
                *p += static_cast<float>(hash_k & 1UL) * 2L - 1L;
                hash_k >>= 1UL;
            }
        }
    }
}

void string_to_clock_hash_v0(int k, const std::string& seq, at::Tensor clocks, at::Tensor count) {
    static std::vector<int> to_bits(256, false);
    to_bits['A'] = 0;
    to_bits['C'] = 1;
    to_bits['G'] = 2;
    to_bits['T'] = 3;

    int8_t * const clocks_data = static_cast<int8_t*>(clocks.data_ptr());
    const int num_kmers = seq.size() - k + 1;
    count.add_(num_kmers);
    for (int pos = 0; pos < num_kmers; ++pos) {
        uint64_t hash = murmur64(1 + k);
        for (int i = 0; i < k; ++i) {
            hash ^= to_bits[seq[pos + i]] << (i + i);
        }
        hash = murmur64(hash);
        for (int i = 0; i < 64; ++i) {
            clocks_data[i] += (hash >> i) & 1;
        }
    }
}

void string_to_clock_hash(int k, const std::string& seq, at::Tensor clocks, at::Tensor count) {
    static std::vector<int> to_bits(256, false);
    to_bits['A'] = 0;
    to_bits['C'] = 1;
    to_bits['G'] = 2;
    to_bits['T'] = 3;

    const int num_words = clocks.size(-1) / 64;
    uint64_t * const clocks_data = static_cast<uint64_t*>(clocks.data_ptr());
    const int num_kmers = seq.size() - k + 1;
    count.add_(num_kmers);
    for (int pos = 0; pos < num_kmers; ++pos) {
        uint64_t hash = 0;
        for (int i = 0; i != k; ++i) {
            hash ^= to_bits[seq[pos + i]] << (i + i);
        }
        for (int w = 0; w != num_words; ++w) {
            const uint64_t hash_w = murmur64(murmur64(1 + w) ^ hash);
            for (int b = 0; b != 8; ++b) {
                const int wb = w * 8 + b;
                clocks_data[wb] = (clocks_data[wb] + ((hash_w >> b) & 0x0101010101010101UL))
                                  & 0x7F7F7F7F7F7F7F7FUL;
            }
        }
    }
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("get_32mers", &get_32mers, "Extract list of 32-mers from a string");
    m.def("string_to_soft_hash", &string_to_soft_hash, "Convert a string to a soft hash");
    m.def("string_to_clock_hash", &string_to_clock_hash, "Convert a string to a clock hash");
    py::class_<KmerCounter>(m, "KmerCounter")
        .def(py::init<>())
        .def("update", &KmerCounter::update)
        .def("truncate_below", &KmerCounter::truncate_below)
        .def("to_dict", &KmerCounter::to_dict);
}
--------------------------------------------------------------------------------
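sketch.cpp exposes its kernels through PYBIND11_MODULE, so it can be compiled through the repo's setup.py or, for quick experiments, JIT-compiled. A hedged sketch of the latter, assuming a working C++ toolchain and torch installation (the module name here is arbitrary, not the repo's build config):

from torch.utils.cpp_extension import load

sketch = load(name="pyrocov_sketch", sources=["pyrocov/sketch.cpp"])

kmers = sketch.get_32mers("ACGT" * 16)  # one packed int64 per 32-mer window
counter = sketch.KmerCounter()
counter.update("ACGT" * 16)
counter.truncate_below(2)
print(kmers.shape, len(counter.to_dict()))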
/pyrocov/softmax_tree.py:
--------------------------------------------------------------------------------
# Copyright Contributors to the Pyro-Cov project.
# SPDX-License-Identifier: Apache-2.0

from collections import defaultdict

import pyro.distributions as dist
import torch

from pyrocov.phylo import Phylogeny


class SoftmaxTree(dist.Distribution):
    """
    Samples a :class:`~pyrocov.phylo.Phylogeny` given parameters of a tree
    embedding.

    :param torch.Tensor bit_times: Tensor of times of each bit in the
        embedding.
    :param torch.Tensor logits: ``(num_leaves, num_bits)``-shaped tensor
        parametrizing the independent Bernoulli distributions over bits
        in each leaf's embedding.
    """

    has_rsample = True  # only wrt times, not parents

    def __init__(self, leaf_times, bit_times, logits):
        assert leaf_times.dim() == 1
        assert bit_times.dim() == 1
        assert logits.dim() == 2
        assert logits.shape == leaf_times.shape + bit_times.shape
        self.leaf_times = leaf_times
        self.bit_times = bit_times
        self._bernoulli = dist.Bernoulli(logits=logits)
        super().__init__()

    @property
    def probs(self):
        return self._bernoulli.probs

    @property
    def logits(self):
        return self._bernoulli.logits

    @property
    def num_leaves(self):
        return self.logits.size(0)

    @property
    def num_bits(self):
        return self.logits.size(1)

    def entropy(self):
        return self._bernoulli.entropy().sum([-1, -2])

    def sample(self, sample_shape=torch.Size()):
        if sample_shape:
            raise NotImplementedError
        raise NotImplementedError("TODO")

    def rsample(self, sample_shape=torch.Size()):
        if sample_shape:
            raise NotImplementedError
        bits = self._bernoulli.sample()
        num_leaves, num_bits = bits.shape
        phylogeny = _decode(self.leaf_times, self.bit_times, bits, self.probs)
        return phylogeny

    def log_prob(self, phylogeny):
        """
        :param ~pyrocov.phylo.Phylogeny phylogeny:
        """
        return self.entropy()  # Is this right?


# TODO Implement a C++ version.
# This costs O(num_bits * num_leaves) sequential time.
def _decode(leaf_times, bit_times, bits, probs):
    # Sort bits by time.
    bit_times, index = bit_times.sort()
    bits = bits[..., index]
    probs = probs[..., index]

    # Construct internal nodes.
    num_leaves, num_bits = bits.shape
    assert num_leaves >= 2
    times = torch.cat([leaf_times, leaf_times.new_empty(num_leaves - 1)])
    parents = torch.empty(2 * num_leaves - 1, dtype=torch.long)
    leaves = torch.arange(num_leaves)

    next_id = num_leaves

    def get_id():
        nonlocal next_id
        next_id += 1
        return next_id - 1

    root = get_id()
    parents[root] = -1
    partitions = [{frozenset(range(num_leaves)): root}]
    for t, b in zip(*bit_times.sort()):
        partitions.append({})
        for partition, p in partitions[-2].items():
            children = defaultdict(set)
            for n in partition:
                bit = bits[n, b].item()
                # TODO Clamp bit if t is later than node n.
                children[bit].add(n)
            if len(children) == 1:
                partitions[-1][partition] = p
                continue
            assert len(children) == 2
            for child in children.values():
                if len(child) == 1:
                    # Terminate at a leaf.
                    c = child.pop()
                else:
                    # Create a new internal node.
                    c = get_id()
                    partitions[-1][frozenset(child)] = c
                parents[c] = p
                times[p] = t
    # Create binarized fans for remaining leaves.
    for partition, p in partitions[-1].items():
        t = times[torch.tensor(list(partition))].min()
        times[p] = t
        partition = set(partition)
        while len(partition) > 2:
            c = get_id()
            times[c] = t
            parents[c] = p
            parents[partition.pop()] = p
            p = c
        parents[partition.pop()] = p
        parents[partition.pop()] = p
        assert not partition

    return Phylogeny.from_unsorted(times, parents, leaves)
--------------------------------------------------------------------------------
/pyrocov/special.py:
--------------------------------------------------------------------------------
# Copyright Contributors to the Pyro-Cov project.
# SPDX-License-Identifier: Apache-2.0

# Adapted from @viswackftw
# https://github.com/pytorch/pytorch/issues/52973#issuecomment-787587188

import math

import torch


def ndtr(value: torch.Tensor):
    """
    Based on the SciPy implementation of ndtr
    """
    sqrt_half = torch.sqrt(torch.tensor(0.5, dtype=value.dtype))
    x = value * sqrt_half
    z = abs(x)
    y = 0.5 * torch.erfc(z)
    output = torch.where(
        z < sqrt_half, 0.5 + 0.5 * torch.erf(x), torch.where(x > 0, 1 - y, y)
    )
    return output


def log_ndtr(value: torch.Tensor):
    """
    Function to compute the log of the normal CDF at value.
    This is based on the TFP implementation.
    """
    dtype = value.dtype
    if dtype == torch.float64:
        lower, upper = -20, 8
    elif dtype == torch.float32:
        lower, upper = -10, 5
    else:
        raise TypeError("value needs to be either float32 or float64")

    # When x < lower, then we perform a fixed series expansion (asymptotic)
    #   = log(cdf(x)) = log(1 - cdf(-x)) = log(1 / 2 * erfc(-x / sqrt(2)))
    #   = log(-1 / sqrt(2 * pi) * exp(-x ** 2 / 2) / x * (1 + sum))
    # When x >= lower and x <= upper, then we simply perform log(cdf(x))
    # When x > upper, then we use the approximation log(cdf(x)) = log(1 - cdf(-x)) \approx -cdf(-x)
    return torch.where(
        value > upper,
        torch.log1p(-ndtr(-value)),
        torch.where(value >= lower, torch.log(ndtr(value)), log_ndtr_series(value)),
    )


def log_ndtr_series(value: torch.Tensor, num_terms=3):
    """
    Function to compute the asymptotic series expansion of the log of normal CDF
    at value.
    This is based on the TFP implementation.
    """
    # sum = sum_{n=1}^{num_terms} (-1)^{n} (2n - 1)!! / x^{2n}))
    value_sq = value**2
    t1 = -0.5 * (math.log(2 * math.pi) + value_sq) - torch.log(-value)
    t2 = torch.zeros_like(value)
    value_even_power = value_sq.clone()
    double_fac = 1
    multiplier = -1
    for n in range(1, num_terms + 1):
        t2.add_(multiplier * double_fac / value_even_power)
        value_even_power.mul_(value_sq)
        double_fac *= 2 * n - 1
        multiplier *= -1
    return t1 + torch.log1p(t2)
--------------------------------------------------------------------------------
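log_ndtr above dispatches on three regimes: an asymptotic series far in the lower tail, a direct log in the middle, and log1p of the complement in the upper tail. A sketch comparing it against SciPy's reference implementation (assuming scipy is installed):

import torch
from scipy.special import log_ndtr as scipy_log_ndtr

from pyrocov.special import log_ndtr

x = torch.linspace(-30.0, 10.0, 41, dtype=torch.float64)
expected = torch.from_numpy(scipy_log_ndtr(x.numpy()))
assert torch.allclose(log_ndtr(x), expected)  # agrees across all three regimes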
/pyrocov/stats.py:
--------------------------------------------------------------------------------
# Copyright Contributors to the Pyro-Cov project.
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import torch
from scipy.special import log_ndtr


def hpd_interval(p: float, samples: torch.Tensor):
    assert 0.5 < p < 1
    assert samples.shape
    pad = int(round((1 - p) * len(samples)))
    assert pad > 0, "too few samples"
    width = samples[-pad:] - samples[:pad]
    lb = width.max(0).indices
    ub = len(samples) - lb - 1
    i = torch.stack([lb, ub])
    return samples.gather(0, i)


def confidence_interval(p: float, samples: torch.Tensor):
    assert 0.5 < p < 1
    assert samples.shape
    pad = (1 - p) / 2
    lk = int(round(pad * (len(samples) - 1)))
    uk = int(round((1 - pad) * (len(samples) - 1)))
    assert lk > 0, "too few samples"
    lb = samples.kthvalue(lk, 0).values
    ub = samples.kthvalue(uk, 0).values
    return torch.stack([lb, ub])


def normal_log10bf(mean, std=1.0):
    """
    Returns ``log10(P[x>0] / P[x<0])`` for ``x ~ N(mean, std)``.
    """
    z = mean / std
    return (log_ndtr(z) - log_ndtr(-z)) / np.log(10)
--------------------------------------------------------------------------------
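confidence_interval reduces along dim 0 with symmetric tail quantiles via kthvalue, while hpd_interval expects its samples pre-sorted along dim 0. A quick sketch of the former:

import torch

from pyrocov.stats import confidence_interval

samples = torch.randn(10000)
lb, ub = confidence_interval(0.95, samples)
print(float(lb), float(ub))  # roughly -1.96 and +1.96 for N(0, 1) samples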
/pyrocov/substitution.py:
--------------------------------------------------------------------------------
# Copyright Contributors to the Pyro-Cov project.
# SPDX-License-Identifier: Apache-2.0

import pyro.distributions as dist
import torch
from pyro.nn import PyroModule, PyroSample, pyro_method


class SubstitutionModel(PyroModule):
    """
    Probabilistic substitution model among a finite number of states
    (typically 4 for nucleotides or 20 for amino acids).

    This returns a continuous time transition matrix.
    """

    @pyro_method
    def matrix_exp(self, dt):
        m = self().to(dt.dtype)
        return (m * dt[:, None, None]).matrix_exp()

    @pyro_method
    def log_matrix_exp(self, dt):
        m = self.matrix_exp(dt)
        m.data.clamp_(torch.finfo(m.dtype).eps)
        return m.log()


class JukesCantor69(SubstitutionModel):
    """
    A simple uniform substitution model with a single latent rate parameter.

    This provides a weak Exponential(1) prior over the rate parameter.

    [1] T.H. Jukes, C.R. Cantor (1969) "Evolution of protein molecules"
    [2] https://en.wikipedia.org/wiki/Models_of_DNA_evolution#JC69_model_(Jukes_and_Cantor_1969)
    """

    def __init__(self, *, dim=4):
        assert isinstance(dim, int) and dim > 0
        super().__init__()
        self.dim = dim
        self.rate = PyroSample(dist.Exponential(1.0))

    def forward(self):
        D = self.dim
        return self.rate * (1.0 / D - torch.eye(D))

    @pyro_method
    def matrix_exp(self, dt):
        D = self.dim
        rate = torch.as_tensor(self.rate, dtype=dt.dtype)
        p = dt.mul(-rate).exp()[:, None, None]
        q = (1 - p) / D
        return torch.where(torch.eye(D, dtype=torch.bool), p + q, q)

    @pyro_method
    def log_matrix_exp(self, dt):
        D = self.dim
        rate = torch.as_tensor(self.rate, dtype=dt.dtype)
        p = dt.mul(-rate).exp()[:, None, None]
        q = (1 - p) / D
        q.data.clamp_(min=torch.finfo(q.dtype).eps)
        on_diag = (p + q).log()
        off_diag = q.log()
        return torch.where(torch.eye(D, dtype=torch.bool), on_diag, off_diag)


class GeneralizedTimeReversible(SubstitutionModel):
    """
    Generalized time-reversible substitution model among ``dim``-many states.

    This provides a weak Dirichlet(2) prior over the steady state distribution
    and a weak Exponential(1) prior over mutation rates.
    """

    def __init__(self, *, dim=4):
        assert isinstance(dim, int) and dim > 0
        super().__init__()
        self.dim = dim
        self.stationary = PyroSample(dist.Dirichlet(torch.full((dim,), 2.0)))
        self.rates = PyroSample(
            dist.Exponential(torch.ones(dim * (dim - 1) // 2)).to_event(1)
        )
        i = torch.arange(dim)
        self._index = (i > i[:, None]).nonzero(as_tuple=False).T

    def forward(self):
        p = self.stationary
        i, j = self._index
        m = torch.zeros(self.dim, self.dim)
        m[i, j] = self.rates
        m = m + m.T * (p / p[:, None])
        m = m - m.sum(dim=-1).diag_embed()
        return m
--------------------------------------------------------------------------------
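JukesCantor69.matrix_exp uses the closed form p * I + (1 - p) / D with p = exp(-rate * dt), so every row of the transition matrix sums to one. A sketch checking that invariant (accessing the PyroSample rate outside inference draws it from its Exponential(1) prior):

import torch

from pyrocov.substitution import JukesCantor69

model = JukesCantor69()
dt = torch.tensor([0.1, 1.0, 10.0])
P = model.matrix_exp(dt)  # shape (3, 4, 4), one matrix per time step
assert torch.allclose(P.sum(-1), torch.ones(3, 4))  # rows are distributions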
/pyrocov/util.py:
--------------------------------------------------------------------------------
# Copyright Contributors to the Pyro-Cov project.
# SPDX-License-Identifier: Apache-2.0

import functools
import gzip
import itertools
import operator
import os
import weakref
from typing import Dict

import pyro
import torch
import tqdm
from torch.distributions import constraints, transform_to


def pearson_correlation(x: torch.Tensor, y: torch.Tensor):
    x = (x - x.mean()) / x.std()
    y = (y - y.mean()) / y.std()
    return (x * y).mean()


def pyro_param(name, shape, constraint=constraints.real):
    transform = transform_to(constraint)
    terms = []
    for subshape in itertools.product(*({1, int(size)} for size in shape)):
        subname = "_".join([name] + list(map(str, subshape)))
        subinit = functools.partial(torch.zeros, subshape)
        terms.append(pyro.param(subname, subinit))
    unconstrained = functools.reduce(operator.add, terms)
    return transform(unconstrained)


def quotient_central_moments(
    fine_values: torch.Tensor, fine_to_coarse: torch.Tensor
) -> torch.Tensor:
    """
    Returns (zeroth, first, second) central moments of each coarse cluster of
    fine values, i.e. (count, mean, stddev).

    :returns: A single stacked tensor of shape ``(3,) + fine_values.shape``.
    """
    C = 1 + int(fine_to_coarse.max())
    moments = torch.zeros(3, C)
    moments[0].scatter_add_(0, fine_to_coarse, torch.ones_like(fine_values))
    moments[1].scatter_add_(0, fine_to_coarse, fine_values)
    moments[1] /= moments[0]
    fine_diff2 = (fine_values - moments[1][fine_to_coarse]).square()
    moments[2].scatter_add_(0, fine_to_coarse, fine_diff2)
    moments[2] /= moments[0]
    moments[2].sqrt_()
    return moments


def weak_memoize_by_id(fn):
    cache = {}
    missing = object()  # An arbitrary value that cannot be returned by fn.

    @functools.wraps(fn)
    def memoized_fn(*args):
        key = tuple(map(id, args))
        result = cache.get(key, missing)
        if result is missing:
            result = cache[key] = fn(*args)
            for arg in args:
                # Register callbacks only for types that support weakref.
                if type(arg).__weakrefoffset__:
                    weakref.finalize(arg, cache.pop, key, None)
        return result

    return memoized_fn


_TENSORS: Dict[tuple, torch.Tensor] = {}


def deduplicate_tensor(x):
    key = x.dtype, x.stride(), x.data_ptr()
    return _TENSORS.setdefault(key, x)


def torch_map(x, **kwargs):
    """
    Calls ``leaf.to(**kwargs)`` on all tensor and module leaves of a nested
    data structure.
    """
    return _torch_map(x, **kwargs)[0]


@functools.singledispatch
def _torch_map(x, **kwargs):
    return x, False


@_torch_map.register(torch.Tensor)
def _torch_map_tensor(x, **kwargs):
    x_ = x.to(**kwargs)
    changed = x_ is not x
    return x_, changed


@_torch_map.register(torch.nn.Module)
def _torch_map_module(x, **kwargs):
    changed = True  # safe
    return x.to(**kwargs), changed


@_torch_map.register(dict)
def _torch_map_dict(x, **kwargs):
    result = type(x)()
    changed = False
    for k, v in x.items():
        v, v_changed = _torch_map(v, **kwargs)
        result[k] = v
        changed = changed or v_changed
    return (result, True) if changed else (x, False)


@_torch_map.register(list)
@_torch_map.register(tuple)
def _torch_map_iterable(x, **kwargs):
    result = []
    changed = False
    for v in x:
        v, v_changed = _torch_map(v, **kwargs)
        result.append(v)
        changed = changed or v_changed
    result = type(x)(result)
    return (result, True) if changed else (x, False)


def pretty_print(x, *, name="", max_items=10):
    if isinstance(x, (int, float, str, bool)):
        print(f"{name} = {repr(x)}")
    elif isinstance(x, torch.Tensor):
        print(f"{name}: {type(x).__name__} of shape {tuple(x.shape)}")
    elif isinstance(x, (tuple, list)):
        print(f"{name}: {type(x).__name__} of length {len(x)}")
    elif isinstance(x, dict):
        print(f"{name}: {type(x).__name__} of length {len(x)}")
        if len(x) <= max_items:
            for k, v in x.items():
                pretty_print(v, name=f"{name}[{repr(k)}]", max_items=max_items)
    else:
        print(f"{name}: {type(x).__name__}")


def generate_colors(num_points=100, lb=0.5, ub=2.5):
    """
    Constructs a quasirandom collection of colors for plotting.
    """
    # http://extremelearning.com.au/unreasonable-effectiveness-of-quasirandom-sequences/
    phi3 = 1.2207440846
    alpha = torch.tensor([1 / phi3**3, 1 / phi3**2, 1 / phi3])
    t = torch.arange(float(2 * num_points))
    rgb = alpha.mul(t[:, None]).add(torch.tensor([0.8, 0.2, 0.1])).fmod(1)
    total = rgb.sum(-1)
    rgb = rgb[(lb <= total) & (total <= ub)]
    rgb = rgb[:num_points]
    assert len(rgb) == num_points
    return [f"#{r:02x}{g:02x}{b:02x}" for r, g, b in rgb.mul(256).long().tolist()]


def open_tqdm(*args, **kwargs):
    with open(*args, **kwargs) as f:
        with tqdm.tqdm(
            total=os.stat(f.fileno()).st_size,
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
            smoothing=0,
        ) as pbar:
            for line in f:
                pbar.update(len(line))
                yield line


def gzip_open_tqdm(filename, mode="rb"):
    with open(filename, "rb") as f, gzip.open(f, mode) as g:
        with tqdm.tqdm(
            total=os.stat(f.fileno()).st_size,
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
            smoothing=0,
        ) as pbar:
            for line in g:
                pbar.n = f.tell()
                pbar.update(0)
                yield line
--------------------------------------------------------------------------------
pyro-ppl/pyro, but got {arg}"
24 |         )
25 | 
26 |     dirname = os.path.join(GITHUB, user)
27 |     if not os.path.exists(dirname):
28 |         os.makedirs(dirname)
29 |     os.chdir(dirname)
30 |     dirname = os.path.join(dirname, repo)
31 |     if not os.path.exists(dirname):
32 |         print(f"Cloning {arg}")
33 |         check_call(
34 |             ["git", "clone", "--depth", "1", f"https://github.com/{user}/{repo}"]
35 |         )
36 |     elif update:
37 |         print(f"Pulling {arg}")
38 |         os.chdir(dirname)
39 |         check_call(["git", "pull"])
40 | 
--------------------------------------------------------------------------------
/scripts/moran.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 | 
4 | import argparse
5 | from collections import defaultdict
6 | 
7 | import numpy as np
8 | import pandas as pd
9 | import torch
10 | 
11 | from pyrocov.sarscov2 import aa_mutation_to_position
12 | 
13 | 
14 | # compute the Moran statistic
15 | def moran(values, distances, lengthscale):
16 |     assert values.size(-1) == distances.size(-1)
17 |     weights = (distances.unsqueeze(-1) - distances.unsqueeze(-2)) / lengthscale
18 |     weights = torch.exp(-weights.pow(2.0))
19 |     weights *= 1.0 - torch.eye(weights.size(-1))
20 |     weights /= weights.sum(-1, keepdim=True)
21 | 
22 |     output = torch.einsum("...ij,...i,...j->...", weights, values, values)
23 |     return output / values.pow(2.0).sum(-1)
24 | 
25 | 
26 | # compute the Moran statistic and do a permutation test with the given number of permutations
27 | def permutation_test(values, distances, lengthscale, num_perm=999):
28 |     values = values - values.mean()
29 |     moran_given = moran(values, distances, lengthscale).item()
30 |     idx = [torch.randperm(distances.size(-1)) for _ in range(num_perm)]
31 |     idx = torch.stack(idx)
32 |     moran_perm = moran(values[idx], distances, lengthscale)
33 |     p_value = (moran_perm >= moran_given).sum().item() + 1
34 |     p_value /= float(num_perm + 1)
35 |     return moran_given, p_value
36 | 
37 | 
38 | def main(args):
39 |     # read in inferred mutations
40 |     df = pd.read_csv("paper/mutations.tsv", sep="\t", index_col=0)
41 |     df = df[["mutation", "Δ log R"]]
42 |     mutations = df.values[:, 0]
43 |     assert mutations.shape == (2904,)
44 |     coefficients = df.values[:, 1] if not args.magnitude else np.abs(df.values[:, 1])
45 |     gene_map = defaultdict(list)
46 |     distance_map = defaultdict(list)
47 | 
48 |     results = []
49 | 
50 |     # collect indices and nucleotide positions corresponding to each mutation
51 |     for i, m in enumerate(mutations):
52 |         gene = m.split(":")[0]
53 |         gene_map[gene].append(i)
54 |         distance_map[gene].append(aa_mutation_to_position(m))
55 | 
56 |     # map over each gene
57 |     for gene, idx in gene_map.items():
58 |         values = torch.from_numpy(np.array(coefficients[idx], dtype=np.float32))
59 |         distances = distance_map[gene]
60 |         distances = torch.from_numpy(np.array(distances) - min(distances))
61 |         gene_size = distances.max().item()
62 |         lengthscale = min(gene_size / 20, 50.0)
63 |         _, p_value = permutation_test(values, distances, lengthscale, num_perm=999999)
64 |         s = "Gene: {} \t #Mut: {} Size: {} \t p-value: {:.6f} Lengthscale: {:.1f}"
65 |         print(s.format(gene, distances.size(0), gene_size, p_value, lengthscale))
66 |         results.append([distances.size(0), gene_size, p_value, lengthscale])
67 | 
68 |     # compute the Moran statistic for the entire genome at multiple lengthscales
69 |     for global_lengthscale in [100.0, 500.0]:
70 |         distances_ = [aa_mutation_to_position(m) for m in mutations]
71 |         distances = torch.from_numpy(
72 |             np.array(distances_, dtype=np.float32) - min(distances_)
73 |         )
74 |         values = torch.tensor(np.array(coefficients, dtype=np.float32)).float()
75 |         _, p_value = permutation_test(
76 |             values, distances, global_lengthscale, num_perm=999999
77 |         )
78 |         genome_size = distances.max().item()
79 |         s = "Entire Genome (#Mut = {}; Size = {}): \t p-value: {:.6f} Lengthscale: {:.1f}"
80 |         print(s.format(distances.size(0), genome_size, p_value, global_lengthscale))
81 |         results.append([distances.size(0), genome_size, p_value, global_lengthscale])
82 | 
83 |     # save results as csv
84 |     results = np.stack(results)
85 |     columns = ["NumMutations", "GeneSize", "PValue", "Lengthscale"]
86 |     index = list(gene_map.keys()) + ["EntireGenome"] * 2
87 |     result = pd.DataFrame(data=results, index=index, columns=columns)
88 |     result.sort_values(["PValue"]).to_csv("paper/moran.csv")
89 | 
90 | 
91 | if __name__ == "__main__":
92 |     parser = argparse.ArgumentParser(description="Compute Moran statistics")
93 |     parser.add_argument("--magnitude", action="store_true")
94 |     args = parser.parse_args()
95 |     main(args)
96 | 
--------------------------------------------------------------------------------
/scripts/preprocess_credits.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 | 
4 | import argparse
5 | import logging
6 | import lzma
7 | import pickle
8 | 
9 | logger = logging.getLogger(__name__)
10 | logging.basicConfig(format="%(relativeCreated) 9d %(message)s", level=logging.INFO)
11 | 
12 | 
13 | def main(args):
14 |     logger.info(f"Loading {args.columns_file_in}")
15 |     with open(args.columns_file_in, "rb") as f:
16 |         columns = pickle.load(f)
17 | 
18 |     logger.info(f"Saving {args.credits_file_out}")
19 |     with lzma.open(args.credits_file_out, "wt") as f:
20 |         f.write("\n".join(sorted(columns["index"])))
21 | 
22 | 
23 | if __name__ == "__main__":
24 |     parser = argparse.ArgumentParser(description="Save GISAID accession numbers")
25 |     parser.add_argument("--columns-file-in", default="results/columns.3000.pkl")
26 |     parser.add_argument("--credits-file-out", default="paper/accession_ids.txt.xz")
27 |     args = parser.parse_args()
28 |     main(args)
29 | 
--------------------------------------------------------------------------------
/scripts/preprocess_gisaid.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 | 
4 | import argparse
5 | import datetime
6 | import json
7 | import logging
8 | import os
9 | import pickle
10 | import warnings
11 | from collections import Counter, defaultdict
12 | 
13 | from pyrocov import pangolin
14 | from pyrocov.geo import gisaid_normalize
15 | from pyrocov.mutrans import START_DATE
16 | from pyrocov.util import open_tqdm
17 | 
18 | logger = logging.getLogger(__name__)
19 | logging.basicConfig(format="%(relativeCreated) 9d %(message)s", level=logging.INFO)
20 | 
21 | DATE_FORMATS = {4: "%Y", 7: "%Y-%m", 10: "%Y-%m-%d"}
22 | 
23 | 
24 | def parse_date(string):
25 |     fmt = DATE_FORMATS.get(len(string))
26 |     if fmt is None:
27 |         # Attempt to fix poorly formatted dates like 2020-09-1.
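        # Zero-pad the month and day so the string's new length matches a
        # DATE_FORMATS key, e.g. "2020-09-1" -> "2020-09-01".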
28 | parts = string.split("-") 29 | parts = parts[:1] + [f"{int(p):>02d}" for p in parts[1:]] 30 | string = "-".join(parts) 31 | fmt = DATE_FORMATS[len(string)] 32 | return datetime.datetime.strptime(string, fmt) 33 | 34 | 35 | FIELDS = ["virus_name", "accession_id", "collection_date", "location", "add_location"] 36 | 37 | 38 | def main(args): 39 | logger.info(f"Filtering {args.gisaid_file_in}") 40 | if not os.path.exists(args.gisaid_file_in): 41 | raise OSError(f"Missing {args.gisaid_file_in}; you may need to request a feed") 42 | os.makedirs("results", exist_ok=True) 43 | 44 | columns = defaultdict(list) 45 | stats = defaultdict(Counter) 46 | covv_fields = ["covv_" + key for key in FIELDS] 47 | 48 | for i, line in enumerate(open_tqdm(args.gisaid_file_in)): 49 | # Optimize for faster reading. 50 | line, _ = line.split(', "sequence": ', 1) 51 | line += "}" 52 | 53 | # Filter out bad data. 54 | datum = json.loads(line) 55 | if len(datum["covv_collection_date"]) < 7: 56 | continue # Drop rows with no month information. 57 | date = parse_date(datum["covv_collection_date"]) 58 | if date < args.start_date: 59 | date = args.start_date # Clip rows before start date. 60 | lineage = datum["covv_lineage"] 61 | if lineage in (None, "None", ""): 62 | continue # Drop rows with unknown lineage. 63 | try: 64 | lineage = pangolin.compress(lineage) 65 | lineage = pangolin.decompress(lineage) 66 | assert lineage 67 | except (ValueError, AssertionError) as e: 68 | warnings.warn(str(e)) 69 | continue 70 | 71 | # Fix duplicate locations. 72 | datum["covv_location"] = gisaid_normalize(datum["covv_location"]) 73 | 74 | # Collate. 75 | columns["lineage"].append(lineage) 76 | for covv_key, key in zip(covv_fields, FIELDS): 77 | columns[key].append(datum[covv_key]) 78 | columns["day"].append((date - args.start_date).days) 79 | 80 | # Aggregate statistics. 81 | stats["date"][datum["covv_collection_date"]] += 1 82 | stats["location"][datum["covv_location"]] += 1 83 | stats["lineage"][lineage] += 1 84 | 85 | if i >= args.truncate: 86 | break 87 | 88 | num_dropped = i + 1 - len(columns["day"]) 89 | logger.info(f"dropped {num_dropped}/{i+1} = {num_dropped*100/(i+1):0.2g}% rows") 90 | 91 | logger.info(f"saving {args.columns_file_out}") 92 | with open(args.columns_file_out, "wb") as f: 93 | pickle.dump(dict(columns), f) 94 | 95 | logger.info(f"saving {args.stats_file_out}") 96 | with open(args.stats_file_out, "wb") as f: 97 | pickle.dump(dict(stats), f) 98 | 99 | 100 | if __name__ == "__main__": 101 | parser = argparse.ArgumentParser(description="Preprocess GISAID data") 102 | parser.add_argument("--gisaid-file-in", default="results/gisaid.json") 103 | parser.add_argument("--columns-file-out", default="results/gisaid.columns.pkl") 104 | parser.add_argument("--stats-file-out", default="results/gisaid.stats.pkl") 105 | parser.add_argument("--start-date", default=START_DATE) 106 | parser.add_argument("--truncate", default=int(1e10), type=int) 107 | args = parser.parse_args() 108 | args.start_date = parse_date(args.start_date) 109 | main(args) 110 | -------------------------------------------------------------------------------- /scripts/preprocess_nextstrain.py: -------------------------------------------------------------------------------- 1 | # Copyright Contributors to the Pyro-Cov project. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | """ 5 | Preprocess Nextstrain open data. 
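(Outputs: pickled column/stat dicts plus dense feature and count tensors for
the downstream growth model; paths are set by the --*-file-out flags below.)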
6 | 7 | This script aggregates the metadata.tsv.gz file available from: 8 | https://docs.nextstrain.org/projects/ncov/en/latest/reference/remote_inputs.html 9 | This file is mirrored on S3 and GCP: 10 | https://data.nextstrain.org/files/ncov/open/metadata.tsv.gz 11 | s3://nextstrain-data/files/ncov/open/metadata.tsv.gz 12 | gs://nextstrain-data/files/ncov/open/metadata.tsv.gz 13 | """ 14 | 15 | import argparse 16 | import datetime 17 | import logging 18 | import pickle 19 | from collections import Counter, defaultdict 20 | 21 | import torch 22 | 23 | from pyrocov.growth import START_DATE, dense_to_sparse 24 | from pyrocov.util import gzip_open_tqdm 25 | 26 | logger = logging.getLogger(__name__) 27 | logging.basicConfig(format="%(relativeCreated) 9d %(message)s", level=logging.INFO) 28 | 29 | 30 | def parse_date(string): 31 | return datetime.datetime.strptime(string, "%Y-%m-%d") 32 | 33 | 34 | def coarsen_locations(args, counts): 35 | """ 36 | Select regions that have at least ``args.min_region_size`` samples. 37 | Remaining regions will be coarsely aggregated up to country level. 38 | """ 39 | locations = set() 40 | coarsen_location = {} 41 | for location, count in counts.items(): 42 | if " / " in location and counts[location] < args.min_region_size: 43 | old = location 44 | location = location.split(" / ")[0] 45 | coarsen_location[old] = location 46 | locations.add(location) 47 | locations = sorted(locations) 48 | logger.info(f"kept {len(locations)}/{len(counts)} locations") 49 | return locations, coarsen_location 50 | 51 | 52 | def main(args): 53 | columns = defaultdict(list) 54 | stats = defaultdict(Counter) 55 | skipped = Counter() 56 | 57 | # Process rows one at a time. 58 | logger.info(f"Reading {args.metadata_file_in}") 59 | header = None 60 | for line in gzip_open_tqdm(args.metadata_file_in, "rt"): 61 | line = line.strip().split("\t") 62 | if header is None: 63 | header = line 64 | continue 65 | row = dict(zip(header, line)) 66 | 67 | # Parse date. 68 | try: 69 | date = parse_date(row["date"]) 70 | except ValueError: 71 | skipped["date"] += 1 72 | continue 73 | day = (date - args.start_date).days 74 | 75 | # Parse location. 76 | location = row["country"] 77 | if location in ("", "?"): 78 | skipped["location"] += 1 79 | continue 80 | division = row["division"] 81 | if division not in ("", "?"): 82 | location += " / " + division 83 | 84 | # Parse lineage. 85 | lineage = row["pango_lineage"] 86 | if lineage in ("", "?", "unclassifiable"): 87 | skipped["lineage"] += 1 88 | continue 89 | assert lineage[0] in "ABCDEFGHIJKLMNOPQRSTUVWXYZ", lineage 90 | 91 | # Append row. 92 | columns["day"].append(day) 93 | columns["location"].append(location) 94 | columns["lineage"].append(lineage) 95 | 96 | # Record stats. 
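        # (per-day / per-location / per-lineage counts, plus amino-acid
        # substitution tallies that later become the lineage feature matrix)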
97 | stats["day"][day] += 1 98 | stats["location"][location] += 1 99 | stats["lineage"][lineage] += 1 100 | for aa in row["aaSubstitutions"].split(","): 101 | stats["aa"][aa] += 1 102 | stats["lineage_aa"][lineage, aa] += 1 103 | columns = dict(columns) 104 | stats = dict(stats) 105 | logger.info(f"kept {len(columns['location'])} rows") 106 | logger.info(f"skipped {sum(skipped.values())} due to:\n{dict(skipped)}") 107 | for k, v in stats.items(): 108 | logger.info(f"found {len(v)} {k}s") 109 | 110 | logger.info(f"saving {args.stats_file_out}") 111 | with open(args.stats_file_out, "wb") as f: 112 | pickle.dump(stats, f) 113 | 114 | logger.info(f"saving {args.columns_file_out}") 115 | with open(args.columns_file_out, "wb") as f: 116 | pickle.dump(columns, f) 117 | 118 | # Create contiguous coordinates. 119 | locations = sorted(stats["location"]) 120 | lineages = sorted(stats["lineage"]) 121 | aa_counts = Counter() 122 | for (lineage, aa), count in stats["lineage_aa"].most_common(): 123 | if count * 2 >= stats["lineage"][lineage]: 124 | aa_counts[aa] += count 125 | logger.info(f"kept {len(aa_counts)}/{len(stats['aa'])} aa substitutions") 126 | aa_mutations = [aa for aa, _ in aa_counts.most_common()] 127 | 128 | # Create a dense feature matrix. 129 | aa_features = torch.zeros(len(lineages), len(aa_mutations), dtype=torch.float) 130 | logger.info( 131 | f"saving {tuple(aa_features.shape)} features to {args.features_file_out}" 132 | ) 133 | for s, lineage in enumerate(lineages): 134 | for f, aa in enumerate(aa_mutations): 135 | count = stats["lineage_aa"].get((lineage, aa)) 136 | if count is None: 137 | continue 138 | aa_features[s, f] = count / stats["lineage"][lineage] 139 | features = { 140 | "lineages": lineages, 141 | "aa_mutations": aa_mutations, 142 | "aa_features": aa_features, 143 | } 144 | with open(args.features_file_out, "wb") as f: 145 | torch.save(features, f) 146 | 147 | # Create a dense dataset. 
148 | locations, coarsen_location = coarsen_locations(args, stats["location"]) 149 | location_id = {location: i for i, location in enumerate(locations)} 150 | lineage_id = {lineage: i for i, lineage in enumerate(lineages)} 151 | T = max(stats["day"]) // args.time_step_days + 1 152 | P = len(locations) 153 | S = len(lineages) 154 | counts = torch.zeros(T, P, S) 155 | for day, location, lineage in zip( 156 | columns["day"], columns["location"], columns["lineage"] 157 | ): 158 | location = coarsen_location.get(location, location) 159 | t = day // args.time_step_days 160 | p = location_id[location] 161 | s = lineage_id[lineage] 162 | counts[t, p, s] += 1 163 | logger.info(f"counts data is {counts.ne(0).float().mean().item()*100:0.3g}% dense") 164 | sparse_counts = dense_to_sparse(counts) 165 | place_lineage_index = counts.ne(0).any(0).reshape(-1).nonzero(as_tuple=True)[0] 166 | logger.info(f"saving {tuple(counts.shape)} counts to {args.dataset_file_out}") 167 | dataset = { 168 | "start_date": args.start_date, 169 | "time_step_days": args.time_step_days, 170 | "locations": locations, 171 | "lineages": lineages, 172 | "mutations": aa_mutations, 173 | "features": aa_features, 174 | "weekly_counts": counts, 175 | "sparse_counts": sparse_counts, 176 | "place_lineage_index": place_lineage_index, 177 | } 178 | with open(args.dataset_file_out, "wb") as f: 179 | torch.save(dataset, f) 180 | 181 | 182 | if __name__ == "__main__": 183 | parser = argparse.ArgumentParser(description=__doc__) 184 | parser.add_argument( 185 | "--metadata-file-in", default="results/nextstrain/metadata.tsv.gz" 186 | ) 187 | parser.add_argument("--columns-file-out", default="results/nextstrain.columns.pkl") 188 | parser.add_argument("--stats-file-out", default="results/nextstrain.stats.pkl") 189 | parser.add_argument("--features-file-out", default="results/nextstrain.features.pt") 190 | parser.add_argument("--dataset-file-out", default="results/nextstrain.data.pt") 191 | parser.add_argument("--start-date", default=START_DATE) 192 | parser.add_argument("--time-step-days", default=14, type=int) 193 | parser.add_argument("--min-region-size", default=50, type=int) 194 | args = parser.parse_args() 195 | args.start_date = parse_date(args.start_date) 196 | main(args) 197 | -------------------------------------------------------------------------------- /scripts/pull_gisaid.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh -ex 2 | 3 | # Ensure data directory (or a link) exists. 4 | test -e results || mkdir results 5 | 6 | # Download. 7 | curl -u $GISAID_USERNAME:$GISAID_PASSWORD --retry 4 \ 8 | https://www.epicov.org/epi3/3p/$GISAID_FEED/export/provision.json.xz \ 9 | > results/gisaid.json.xz 10 | 11 | # Decompress, keeping the original. 
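# (-d decompress, -k keep the input, -f overwrite, -T0 use all cores, -v verbose)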
12 | xz -d -k -f -T0 -v results/gisaid.json.xz
13 | 
--------------------------------------------------------------------------------
/scripts/pull_nextstrain.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh +ex
2 | 
3 | mkdir -p results/nextstrain
4 | 
5 | curl https://data.nextstrain.org/files/ncov/open/metadata.tsv.gz -o results/nextstrain/metadata.tsv.gz
6 | 
7 | gunzip -kvf results/nextstrain/metadata.tsv.gz
8 | 
--------------------------------------------------------------------------------
/scripts/pull_usher.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh +ex
2 | 
3 | mkdir -p results/usher
4 | 
5 | # These are the available files:
6 | # public-latest.version.txt
7 | # public-latest.metadata.tsv.gz
8 | # public-latest.all.masked.pb.gz
9 | # public-latest.all.masked.vcf.gz
10 | # public-latest.all.nwk.gz
11 | url='http://hgdownload.soe.ucsc.edu/goldenPath/wuhCor1/UShER_SARS-CoV-2'
12 | for name in version.txt metadata.tsv.gz all.masked.pb.gz
13 | do
14 |     curl $url/public-latest.$name -o results/usher/$name
15 | done
16 | 
17 | gunzip -kvf results/usher/*.gz
18 | 
--------------------------------------------------------------------------------
/scripts/rank_mutations.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 | 
4 | import argparse
5 | import functools
6 | import logging
7 | import math
8 | import os
9 | 
10 | import torch
11 | from pyro import poutine
12 | 
13 | from pyrocov import mutrans
14 | 
15 | logger = logging.getLogger(__name__)
16 | logging.basicConfig(format="%(relativeCreated) 9d %(message)s", level=logging.INFO)
17 | 
18 | 
19 | def cached(filename):
20 |     def decorator(fn):
21 |         @functools.wraps(fn)
22 |         def cached_fn(*args, **kwargs):
23 |             f = filename(*args, **kwargs) if callable(filename) else filename
24 |             if args[0].force or not os.path.exists(f):
25 |                 result = fn(*args, **kwargs)
26 |                 logger.info(f"saving {f}")
27 |                 torch.save(result, f)
28 |             else:
29 |                 logger.info(f"loading cached {f}")
30 |                 result = torch.load(f, map_location=args[0].device)
31 |             return result
32 | 
33 |         return cached_fn
34 | 
35 |     return decorator
36 | 
37 | 
38 | @cached("results/mutrans.data.pt")
39 | def load_data(args):
40 |     return mutrans.load_gisaid_data(device=args.device)
41 | 
42 | 
43 | @cached("results/rank_mutations.rank_mf_svi.pt")
44 | def rank_mf_svi(args, dataset):
45 |     result = mutrans.fit_mf_svi(
46 |         dataset,
47 |         mutrans.model,
48 |         learning_rate=args.svi_learning_rate,
49 |         num_steps=args.svi_num_steps,
50 |         log_every=args.log_every,
51 |         seed=args.seed,
52 |     )
53 |     result["args"] = (args,)
54 |     sigma = result["mean"] / result["std"]
55 |     result["ranks"] = sigma.sort(0, descending=True).indices
56 |     result["cond_data"] = {
57 |         "feature_scale": result["median"]["feature_scale"].item(),
58 |         "concentration": result["median"]["concentration"].item(),
59 |     }
60 |     del result["guide"]
61 |     return result
62 | 
63 | 
64 | @cached("results/rank_mutations.rank_full_svi.pt")
65 | def rank_full_svi(args, dataset):
66 |     result = mutrans.fit_full_svi(
67 |         dataset,
68 |         mutrans.model,
69 |         learning_rate=args.full_learning_rate,
70 |         learning_rate_decay=args.full_learning_rate_decay,
71 |         num_steps=args.full_num_steps,
72 |         log_every=args.log_every,
73 |         seed=args.seed,
74 |     )
75 |     result["args"] = (args,)
76 |     result["mean"] = result["params"]["rate_coef_loc"]
77 |     scale_tril = result["params"]["rate_coef_scale_tril"]
scale_tril = result["params"]["rate_coef_scale_tril"] 78 | result["cov"] = scale_tril @ scale_tril.T 79 | result["var"] = result["cov"].diag() 80 | result["std"] = result["var"].sqrt() 81 | sigma = result["mean"] / result["std"] 82 | result["ranks"] = sigma.sort(0, descending=True).indices 83 | result["cond_data"] = { 84 | "feature_scale": result["median"]["feature_scale"].item(), 85 | "concentration": result["median"]["concentration"].item(), 86 | } 87 | return result 88 | 89 | 90 | @cached("results/rank_mutations.hessian.pt") 91 | def compute_hessian(args, dataset, result): 92 | logger.info("Computing Hessian") 93 | features = dataset["features"] 94 | weekly_clades = dataset["weekly_clades"] 95 | rate_coef = result["median"]["rate_coef"].clone().requires_grad_() 96 | 97 | cond_data = result["median"].copy() 98 | cond_data.pop("rate") 99 | cond_data.pop("rate_coef") 100 | model = poutine.condition(mutrans.model, cond_data) 101 | 102 | def log_prob(rate_coef): 103 | with poutine.trace() as tr: 104 | with poutine.condition(data={"rate_coef": rate_coef}): 105 | model(weekly_clades, features) 106 | return tr.trace.log_prob_sum() 107 | 108 | hessian = torch.autograd.functional.hessian( 109 | log_prob, 110 | rate_coef, 111 | create_graph=False, 112 | strict=True, 113 | ) 114 | 115 | result = { 116 | "args": args, 117 | "mutations": dataset["mutations"], 118 | "initial_ranks": result, 119 | "mean": result["mean"], 120 | "hessian": hessian, 121 | } 122 | 123 | logger.info("Computing covariance") 124 | result["cov"] = _sym_inverse(-hessian) 125 | result["var"] = result["cov"].diag() 126 | result["std"] = result["var"].sqrt() 127 | sigma = result["mean"] / result["std"] 128 | result["ranks"] = sigma.sort(0, descending=True).indices 129 | return result 130 | 131 | 132 | def _sym_inverse(mat): 133 | eye = torch.eye(len(mat)) 134 | e = None 135 | for exponent in [-math.inf] + list(range(-20, 1)): 136 | eps = 10**exponent 137 | try: 138 | u = torch.cholesky(eye * eps + mat) 139 | except RuntimeError as e: # noqa F841 140 | continue 141 | logger.info(f"Added {eps:g} to Hessian diagonal") 142 | return torch.cholesky_inverse(u) 143 | raise e from None 144 | 145 | 146 | def _fit_map_filename(args, dataset, cond_data, guide=None, without_feature=None): 147 | return f"results/rank_mutations.{guide is None}.{without_feature}.pt" 148 | 149 | 150 | @cached(_fit_map_filename) 151 | def fit_map(args, dataset, cond_data, guide=None, without_feature=None): 152 | if without_feature is not None: 153 | # Drop feature. 154 | dataset = dataset.copy() 155 | dataset["features"] = dataset["features"].clone() 156 | dataset["features"][:, without_feature] = 0 157 | 158 | # Condition model. 159 | cond_data = {k: torch.as_tensor(v) for k, v in cond_data.items()} 160 | model = poutine.condition(mutrans.model, cond_data) 161 | 162 | # Fit. 163 | result = mutrans.fit_map( 164 | dataset, 165 | model, 166 | guide, 167 | learning_rate=args.map_learning_rate, 168 | num_steps=args.map_num_steps, 169 | log_every=args.log_every, 170 | seed=args.seed, 171 | ) 172 | 173 | result["args"] = args 174 | result["guide"] = guide 175 | if without_feature is None: 176 | result["mutation"] = None 177 | else: 178 | result["mutation"] = dataset["mutations"][without_feature] 179 | return result 180 | 181 | 182 | def rank_map(args, dataset, initial_ranks): 183 | """ 184 | Given an initial approximate ranking of features, compute MAP log 185 | likelihood ratios of the most significant features. 186 | """ 187 | # Fit an initial model for warm-starting. 
188 | cond_data = initial_ranks["cond_data"] 189 | if args.warm_start: 190 | guide = fit_map(args, dataset, cond_data)["guide"] 191 | else: 192 | guide = None 193 | 194 | # Evaluate on the null hypothesis + the most positive features. 195 | dropouts = {} 196 | for feature in [None] + initial_ranks["ranks"].tolist(): 197 | dropouts[feature] = fit_map(args, dataset, cond_data, guide, feature) 198 | 199 | result = { 200 | "args": args, 201 | "mutations": dataset["mutations"], 202 | "initial_ranks": initial_ranks, 203 | "dropouts": dropouts, 204 | } 205 | logger.info("saving results/rank_mutations.pt") 206 | torch.save(result, "results/rank_mutations.pt") 207 | 208 | 209 | def main(args): 210 | if args.double: 211 | torch.set_default_dtype(torch.double) 212 | if args.cuda: 213 | torch.set_default_tensor_type( 214 | torch.cuda.DoubleTensor if args.double else torch.cuda.FloatTensor 215 | ) 216 | 217 | dataset = load_data(args) 218 | if args.full: 219 | initial_ranks = rank_full_svi(args, dataset) 220 | else: 221 | initial_ranks = rank_mf_svi(args, dataset) 222 | if args.hessian: 223 | compute_hessian(args, dataset, initial_ranks) 224 | if args.dropout: 225 | rank_map(args, dataset, initial_ranks) 226 | 227 | 228 | if __name__ == "__main__": 229 | parser = argparse.ArgumentParser( 230 | description="Rank mutations via SVI and leave-feature-out MAP" 231 | ) 232 | parser.add_argument("--full", action="store_true") 233 | parser.add_argument("--full-num-steps", default=10001, type=int) 234 | parser.add_argument("--full-learning-rate", default=0.01, type=float) 235 | parser.add_argument("--full-learning-rate-decay", default=0.01, type=float) 236 | parser.add_argument("--svi-num-steps", default=1001, type=int) 237 | parser.add_argument("--svi-learning-rate", default=0.05, type=float) 238 | parser.add_argument("--map-num-steps", default=1001, type=int) 239 | parser.add_argument("--map-learning-rate", default=0.05, type=float) 240 | parser.add_argument("--dropout", action="store_true") 241 | parser.add_argument("--hessian", action="store_true") 242 | parser.add_argument("--warm-start", action="store_true") 243 | parser.add_argument("--double", action="store_true", default=True) 244 | parser.add_argument("--single", action="store_false", dest="double") 245 | parser.add_argument( 246 | "--cuda", action="store_true", default=torch.cuda.is_available() 247 | ) 248 | parser.add_argument("--cpu", dest="cuda", action="store_false") 249 | parser.add_argument("--seed", default=20210319, type=int) 250 | parser.add_argument("-f", "--force", action="store_true") 251 | parser.add_argument("-l", "--log-every", default=100, type=int) 252 | args = parser.parse_args() 253 | args.device = "cuda" if args.cuda else "cpu" 254 | main(args) 255 | -------------------------------------------------------------------------------- /scripts/run_backtesting.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | backtesting_days=$(seq -s, 150 14 550) 4 | 5 | python mutrans.py --backtesting-max-day $backtesting_days --forecast-steps 12 6 | -------------------------------------------------------------------------------- /scripts/update_headers.py: -------------------------------------------------------------------------------- 1 | # Copyright Contributors to the Pyro-Cov project. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import argparse 5 | import glob 6 | import os 7 | import sys 8 | 9 | root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 10 | blacklist = ["/build/", "/dist/", "/pyrocov/external/"] 11 | file_types = [ 12 | ("*.py", "# {}"), 13 | ("*.cpp", "// {}"), 14 | ] 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--check", action="store_true") 18 | args = parser.parse_args() 19 | dirty = [] 20 | 21 | for basename, comment in file_types: 22 | copyright_line = comment.format("Copyright Contributors to the Pyro-Cov project.\n") 23 | # See https://spdx.org/ids-how 24 | spdx_line = comment.format("SPDX-License-Identifier: Apache-2.0\n") 25 | 26 | filenames = glob.glob(os.path.join(root, "**", basename), recursive=True) 27 | filenames.sort() 28 | filenames = [ 29 | filename 30 | for filename in filenames 31 | if not any(word in filename for word in blacklist) 32 | ] 33 | for filename in filenames: 34 | with open(filename) as f: 35 | lines = f.readlines() 36 | 37 | # Ignore empty files like __init__.py 38 | if all(line.isspace() for line in lines): 39 | continue 40 | 41 | # Ensure first few line are copyright notices. 42 | changed = False 43 | lineno = 0 44 | if not lines[lineno].startswith(comment.format("Copyright")): 45 | lines.insert(lineno, copyright_line) 46 | changed = True 47 | lineno += 1 48 | while lines[lineno].startswith(comment.format("Copyright")): 49 | lineno += 1 50 | 51 | # Ensure next line is an SPDX short identifier. 52 | if not lines[lineno].startswith(comment.format("SPDX-License-Identifier")): 53 | lines.insert(lineno, spdx_line) 54 | changed = True 55 | lineno += 1 56 | 57 | # Ensure next line is blank. 58 | if not lines[lineno].isspace(): 59 | lines.insert(lineno, "\n") 60 | changed = True 61 | 62 | if not changed: 63 | continue 64 | 65 | if args.check: 66 | dirty.append(filename) 67 | continue 68 | 69 | with open(filename, "w") as f: 70 | f.write("".join(lines)) 71 | 72 | print("updated {}".format(filename[len(root) + 1 :])) 73 | 74 | if dirty: 75 | print("The following files need license headers:\n{}".format("\n".join(dirty))) 76 | print("Please run 'make license'") 77 | sys.exit(1) 78 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | ignore = E741,E203,W503 4 | per_file_ignores = 5 | pyrocov/io.py:E226 6 | exclude = 7 | generate_epiToPublicAndDate.py 8 | 9 | [isort] 10 | profile = black 11 | skip_glob = .ipynb_checkpoints 12 | known_first_party = pyrocov 13 | known_third_party = opt_einsum, pyro, torch, torchvision 14 | 15 | [tool:pytest] 16 | filterwarnings = error 17 | ignore::PendingDeprecationWarning 18 | ignore::DeprecationWarning 19 | once::DeprecationWarning 20 | ignore:Failed to find .* pangolin aliases may be stale:RuntimeWarning 21 | 22 | [mypy] 23 | ignore_missing_imports = True 24 | allow_redefinition = True 25 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright Contributors to the Pyro-Cov project. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import re 5 | import sys 6 | 7 | from setuptools import find_packages, setup 8 | 9 | with open("pyrocov/__init__.py") as f: 10 | for line in f: 11 | match = re.match('^__version__ = "(.*)"$', line) 12 | if match: 13 | __version__ = match.group(1) 14 | break 15 | 16 | try: 17 | long_description = open("README.md", encoding="utf-8").read() 18 | except Exception as e: 19 | sys.stderr.write("Failed to read README.md: {}\n".format(e)) 20 | sys.stderr.flush() 21 | long_description = "" 22 | 23 | setup( 24 | name="pyrocov", 25 | version="0.1.0", 26 | description="Pyro tools for Sars-CoV-2 analysis", 27 | long_description=long_description, 28 | long_description_content_type="text/markdown", 29 | packages=find_packages(include=["pyrocov"]), 30 | url="http://pyro.ai", 31 | author="Pyro team at the Broad Institute of MIT and Harvard", 32 | author_email="fobermey@broadinstitute.org", 33 | install_requires=[ 34 | "biopython>=1.54", 35 | "pyro-ppl>=1.7", 36 | "geopy", 37 | "gpytorch", 38 | "scikit-learn", 39 | "umap-learn", 40 | "mappy", 41 | "protobuf>=3.12,<3.13", # pinned by usher 42 | "tqdm", 43 | "colorcet", 44 | ], 45 | extras_require={ 46 | "test": [ 47 | "black", 48 | "isort>=5.0", 49 | "flake8", 50 | "pytest>=5.0", 51 | "pytest-xdist", 52 | "mypy>=0.812", 53 | "types-protobuf", 54 | ], 55 | }, 56 | python_requires=">=3.6", 57 | keywords="pyro pytorch phylogenetic machine learning", 58 | license="Apache 2.0", 59 | classifiers=[ 60 | "Intended Audience :: Developers", 61 | "Intended Audience :: Science/Research", 62 | "License :: OSI Approved :: Apache Software License", 63 | "Operating System :: POSIX :: Linux", 64 | "Operating System :: MacOS :: MacOS X", 65 | "Programming Language :: Python :: 3.6", 66 | "Programming Language :: Python :: 3.7", 67 | ], 68 | ) 69 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-2019 Uber Technologies, Inc. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import logging 5 | import os 6 | 7 | # create log handler for tests 8 | level = logging.INFO if "CI" in os.environ else logging.DEBUG 9 | logging.basicConfig(format="%(levelname).1s \t %(message)s", level=level) 10 | -------------------------------------------------------------------------------- /test/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright Contributors to the Pyro-Cov project. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import logging 5 | import os 6 | 7 | import pyro 8 | 9 | level = logging.INFO if "CI" in os.environ else logging.DEBUG 10 | logging.basicConfig(format="%(levelname).1s \t %(message)s", level=level) 11 | 12 | 13 | def pytest_runtest_setup(item): 14 | pyro.clear_param_store() 15 | pyro.enable_validation(True) 16 | pyro.set_rng_seed(0) 17 | -------------------------------------------------------------------------------- /test/test_distributions.py: -------------------------------------------------------------------------------- 1 | # Copyright Contributors to the Pyro-Cov project. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import pytest 5 | import torch 6 | from pyro.distributions.testing.gof import auto_goodness_of_fit 7 | 8 | from pyrocov.distributions import SoftLaplace 9 | 10 | TEST_FAILURE_RATE = 1e-2 11 | 12 | 13 | @pytest.mark.parametrize( 14 | "Dist, params", 15 | [ 16 | (SoftLaplace, {"loc": 0.0, "scale": 1.0}), 17 | (SoftLaplace, {"loc": 1.0, "scale": 1.0}), 18 | (SoftLaplace, {"loc": 0.0, "scale": 10.0}), 19 | ], 20 | ) 21 | def test_gof(Dist, params): 22 | num_samples = 50000 23 | d = Dist(**params) 24 | samples = d.sample(torch.Size([num_samples])) 25 | probs = d.log_prob(samples).exp() 26 | 27 | # Test each batch independently. 28 | probs = probs.reshape(num_samples, -1) 29 | samples = samples.reshape(probs.shape + d.event_shape) 30 | for b in range(probs.size(-1)): 31 | gof = auto_goodness_of_fit(samples[:, b], probs[:, b]) 32 | assert gof > TEST_FAILURE_RATE 33 | -------------------------------------------------------------------------------- /test/test_io.py: -------------------------------------------------------------------------------- 1 | # Copyright Contributors to the Pyro-Cov project. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import glob 5 | import os 6 | 7 | import pytest 8 | import torch 9 | from Bio import Phylo 10 | 11 | from pyrocov.io import read_alignment, read_nexus_trees, stack_nexus_trees 12 | 13 | ROOT = os.path.dirname(os.path.dirname(__file__)) 14 | FILENAME = os.path.join(ROOT, "data", "GTR4G_posterior.trees") 15 | 16 | 17 | @pytest.mark.skipif(not os.path.exists(FILENAME), reason="file unavailable") 18 | @pytest.mark.xfail(reason="Python <3.8 cannot .read() large files", run=False) 19 | def test_bio_phylo_parse(): 20 | trees = Phylo.parse(FILENAME, format="nexus") 21 | for tree in trees: 22 | print(tree.count_terminals()) 23 | 24 | 25 | @pytest.mark.skipif(not os.path.exists(FILENAME), reason="file unavailable") 26 | @pytest.mark.parametrize("processes", [0, 2]) 27 | def test_read_nexus_trees(processes): 28 | trees = read_nexus_trees(FILENAME, max_num_trees=5, processes=processes) 29 | trees = list(trees) 30 | assert len(trees) == 5 31 | for tree in trees: 32 | assert tree.count_terminals() == 772 33 | 34 | 35 | @pytest.mark.skipif(not os.path.exists(FILENAME), reason="file unavailable") 36 | @pytest.mark.parametrize("processes", [0, 2]) 37 | def test_stack_nexus_trees(processes): 38 | phylo = stack_nexus_trees(FILENAME, max_num_trees=5, processes=processes) 39 | assert phylo.batch_shape == (5,) 40 | 41 | 42 | @pytest.mark.parametrize("filename", glob.glob("data/treebase/DS*.nex")) 43 | def test_read_alignment(filename): 44 | probs = read_alignment(filename) 45 | assert probs.dim() == 3 46 | assert torch.isfinite(probs).all() 47 | -------------------------------------------------------------------------------- /test/test_markov_tree.py: -------------------------------------------------------------------------------- 1 | # Copyright Contributors to the Pyro-Cov project. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import pyro.distributions as dist 5 | import pytest 6 | import torch 7 | 8 | from pyrocov.markov_tree import MarkovTree, _interpolate_lmve, _mpm 9 | from pyrocov.phylo import Phylogeny 10 | 11 | 12 | def grad(output, inputs, **kwargs): 13 | if not output.requires_grad: 14 | return list(map(torch.zeros_like, inputs)) 15 | return torch.autograd.grad(output, inputs, **kwargs) 16 | 17 | 18 | @pytest.mark.parametrize("size", range(2, 10)) 19 | def test_mpm(size): 20 | matrix = torch.randn(size, size).exp() 21 | matrix /= matrix.sum(dim=-1, keepdim=True) # Make stochastic. 22 | matrix = (matrix + 4 * torch.eye(size)) / 5 # Make diagonally dominant. 23 | vector = torch.randn(size) 24 | 25 | for t in range(0, 10): 26 | expected = vector @ matrix.matrix_power(t) 27 | actual = _mpm(matrix, torch.tensor(float(t)), vector) 28 | 29 | assert actual.shape == expected.shape 30 | assert torch.allclose(actual, expected) 31 | 32 | 33 | @pytest.mark.parametrize("size", [2, 3, 4, 5]) 34 | @pytest.mark.parametrize("duration", [1, 2, 3, 4, 5]) 35 | def test_interpolate_lmve_smoke(size, duration): 36 | matrix = torch.randn(duration, size, size).exp() 37 | log_vector = torch.randn(size) 38 | t0 = -0.6 39 | while t0 < duration + 0.9: 40 | t1 = t0 + 0.2 41 | while t1 < duration + 0.9: 42 | actual = _interpolate_lmve( 43 | torch.tensor(t0), torch.tensor(t1), matrix, log_vector 44 | ) 45 | assert actual.shape == log_vector.shape 46 | t1 += 1 47 | t0 += 1 48 | 49 | 50 | @pytest.mark.parametrize("num_states", [3, 7]) 51 | @pytest.mark.parametrize("num_leaves", [4, 16, 17]) 52 | @pytest.mark.parametrize("duration", [1, 5]) 53 | @pytest.mark.parametrize("num_samples", [1, 2, 3]) 54 | def test_markov_tree_log_prob(num_samples, duration, num_leaves, num_states): 55 | phylo = Phylogeny.generate(num_leaves, num_samples=num_samples) 56 | phylo.times.mul_(duration * 0.25).add_(0.75 * duration) 57 | phylo.times.round_() # Required for naive-vs-likelihood agreement. 58 | 59 | leaf_state = dist.Categorical(torch.ones(num_states)).sample([num_leaves]) 60 | 61 | state_trans = torch.randn(duration, num_states, num_states).mul(0.1).exp() 62 | state_trans /= state_trans.sum(dim=-1, keepdim=True) 63 | state_trans += 4 * torch.eye(num_states) 64 | state_trans /= state_trans.sum(dim=-1, keepdim=True) 65 | state_trans.requires_grad_() 66 | 67 | dist1 = MarkovTree(phylo, state_trans, method="naive") 68 | dist2 = MarkovTree(phylo, state_trans, method="likelihood") 69 | 70 | logp1 = dist1.log_prob(leaf_state) 71 | logp2 = dist2.log_prob(leaf_state) 72 | assert torch.allclose(logp1, logp2) 73 | 74 | grad1 = grad(logp1.logsumexp(0), [state_trans], allow_unused=True)[0] 75 | grad2 = grad(logp2.logsumexp(0), [state_trans], allow_unused=True)[0] 76 | grad1 = grad1 - grad1.mean(dim=-1, keepdim=True) 77 | grad2 = grad2 - grad2.mean(dim=-1, keepdim=True) 78 | assert torch.allclose(grad1, grad2, rtol=1e-4, atol=1e-4) 79 | -------------------------------------------------------------------------------- /test/test_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright Contributors to the Pyro-Cov project. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import pyro.distributions as dist 5 | import pytest 6 | import torch 7 | from torch.autograd import grad 8 | 9 | from pyrocov.ops import ( 10 | logistic_logsumexp, 11 | sparse_multinomial_likelihood, 12 | sparse_poisson_likelihood, 13 | ) 14 | 15 | 16 | @pytest.mark.parametrize("T,P,S", [(5, 6, 7)]) 17 | @pytest.mark.parametrize("backend", ["sequential"]) 18 | def test_logistic_logsumexp(T, P, S, backend): 19 | alpha = torch.randn(P, S, requires_grad=True) 20 | beta = torch.randn(P, S, requires_grad=True) 21 | delta = torch.randn(P, S, requires_grad=True) 22 | tau = torch.randn(T, P) 23 | 24 | expected = logistic_logsumexp(alpha, beta, delta, tau, backend="naive") 25 | actual = logistic_logsumexp(alpha, beta, delta, tau, backend=backend) 26 | assert torch.allclose(actual, expected) 27 | 28 | probe = torch.randn(expected.shape) 29 | expected_grads = grad((probe * expected).sum(), [alpha, beta, delta]) 30 | actual_grads = grad((probe * actual).sum(), [alpha, beta, delta]) 31 | for e, a, name in zip(expected_grads, actual_grads, ["alpha", "beta", "delta"]): 32 | assert torch.allclose(a, e), name 33 | 34 | 35 | @pytest.mark.parametrize("T,P,S", [(2, 3, 4), (5, 6, 7), (8, 9, 10)]) 36 | def test_sparse_poisson_likelihood(T, P, S): 37 | log_rate = torch.randn(T, P, S) 38 | d = dist.Poisson(log_rate.exp()) 39 | value = d.sample() 40 | assert 0.1 < (value == 0).float().mean() < 0.9, "weak test" 41 | expected = d.log_prob(value).sum() 42 | 43 | full_log_rate = log_rate.logsumexp(-1) 44 | nnz = value.nonzero(as_tuple=True) 45 | nonzero_value = value[nnz] 46 | nonzero_log_rate = log_rate[nnz] 47 | actual = sparse_poisson_likelihood(full_log_rate, nonzero_log_rate, nonzero_value) 48 | assert torch.allclose(actual, expected) 49 | 50 | 51 | @pytest.mark.parametrize("T,P,S", [(2, 3, 4), (5, 6, 7), (8, 9, 10)]) 52 | def test_sparse_multinomial_likelihood(T, P, S): 53 | logits = torch.randn(T, P, S) 54 | value = dist.Poisson(logits.exp()).sample() 55 | 56 | d = dist.Multinomial(logits=logits, validate_args=False) 57 | assert 0.1 < (value == 0).float().mean() < 0.9, "weak test" 58 | expected = d.log_prob(value).sum() 59 | 60 | logits = logits.log_softmax(-1) 61 | total_count = value.sum(-1) 62 | nnz = value.nonzero(as_tuple=True) 63 | nonzero_value = value[nnz] 64 | nonzero_logits = logits[nnz] 65 | actual = sparse_multinomial_likelihood(total_count, nonzero_logits, nonzero_value) 66 | assert torch.allclose(actual, expected) 67 | -------------------------------------------------------------------------------- /test/test_phylo.py: -------------------------------------------------------------------------------- 1 | # Copyright Contributors to the Pyro-Cov project. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import pytest 5 | import torch 6 | 7 | from pyrocov.phylo import Phylogeny 8 | 9 | 10 | @pytest.mark.parametrize("num_leaves", range(1, 50)) 11 | def test_smoke(num_leaves): 12 | phylo = Phylogeny.generate(num_leaves) 13 | phylo.num_lineages() 14 | phylo.hash_topology() 15 | phylo.time_mrca() 16 | 17 | 18 | @pytest.mark.parametrize("num_leaves", range(1, 10)) 19 | @pytest.mark.parametrize("num_samples", range(1, 5)) 20 | def test_smoke_batch(num_leaves, num_samples): 21 | phylo = Phylogeny.generate(num_leaves, num_samples=num_samples) 22 | phylo.num_lineages() 23 | phylo.hash_topology() 24 | phylo.time_mrca() 25 | 26 | 27 | def test_time_mrca(): 28 | # 0 0 29 | # 1 \ 1 30 | # /| 2 2 31 | # 3 4 |\ 3 32 | # 5 \ 4 33 | # / \ 6 5 34 | # 7 8 6 35 | times = torch.tensor([0.0, 1.0, 2.0, 3.0, 3.0, 4.0, 5.0, 6.0, 6.0]) 36 | parents = torch.tensor([-1, 0, 0, 1, 1, 2, 2, 5, 5]) 37 | leaves = torch.tensor([3, 4, 6, 7, 8]) 38 | phylo = Phylogeny(times, parents, leaves) 39 | 40 | actual = phylo.time_mrca() 41 | expected = torch.tensor( 42 | [ 43 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 44 | [0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0], 45 | [0.0, 0.0, 2.0, 0.0, 0.0, 2.0, 2.0, 2.0, 2.0], 46 | [0.0, 1.0, 0.0, 3.0, 1.0, 0.0, 0.0, 0.0, 0.0], 47 | [0.0, 1.0, 0.0, 1.0, 3.0, 0.0, 0.0, 0.0, 0.0], 48 | [0.0, 0.0, 2.0, 0.0, 0.0, 4.0, 2.0, 4.0, 4.0], 49 | [0.0, 0.0, 2.0, 0.0, 0.0, 2.0, 5.0, 2.0, 2.0], 50 | [0.0, 0.0, 2.0, 0.0, 0.0, 4.0, 2.0, 6.0, 4.0], 51 | [0.0, 0.0, 2.0, 0.0, 0.0, 4.0, 2.0, 4.0, 6.0], 52 | ] 53 | ) 54 | assert (actual == expected).all() 55 | 56 | actual = phylo.leaf_time_mrca() 57 | expected = torch.tensor( 58 | [ 59 | [3.0, 1.0, 0.0, 0.0, 0.0], 60 | [1.0, 3.0, 0.0, 0.0, 0.0], 61 | [0.0, 0.0, 5.0, 2.0, 2.0], 62 | [0.0, 0.0, 2.0, 6.0, 4.0], 63 | [0.0, 0.0, 2.0, 4.0, 6.0], 64 | ] 65 | ) 66 | assert (actual == expected).all() 67 | -------------------------------------------------------------------------------- /test/test_sarscov2.py: -------------------------------------------------------------------------------- 1 | # Copyright Contributors to the Pyro-Cov project. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | from pyrocov.sarscov2 import nuc_mutations_to_aa_mutations 5 | 6 | 7 | def test_nuc_to_aa(): 8 | assert nuc_mutations_to_aa_mutations(["A23403G"]) == ["S:D614G"] 9 | -------------------------------------------------------------------------------- /test/test_sketch.py: -------------------------------------------------------------------------------- 1 | # Copyright Contributors to the Pyro-Cov project. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import re 5 | 6 | import pyro.distributions as dist 7 | import pytest 8 | import torch 9 | 10 | from pyrocov.sketch import AMSSketcher, ClockSketcher, KmerCounter 11 | 12 | 13 | def random_string(size): 14 | probs = torch.tensor([1.0, 1.0, 1.0, 1.0, 0.05]) 15 | probs /= probs.sum() 16 | string = "".join("ACGTN"[i] for i in dist.Categorical(probs).sample([size])) 17 | return string 18 | 19 | 20 | def test_kmer_counter(): 21 | string = random_string(10000) 22 | 23 | results = {} 24 | for backend in ["python", "cpp"]: 25 | results[backend] = KmerCounter(backend=backend) 26 | for part in re.findall("[ACTG]+", string): 27 | results[backend].update(part) 28 | results[backend].flush() 29 | 30 | expected = results["python"] 31 | actual = results["cpp"] 32 | assert actual == expected 33 | 34 | 35 | @pytest.mark.parametrize("min_k,max_k", [(2, 2), (2, 4), (3, 12)]) 36 | @pytest.mark.parametrize("bits", [16]) 37 | def test_string_to_soft_hash(min_k, max_k, bits): 38 | string = random_string(1000) 39 | 40 | results = {} 41 | for backend in ["python", "cpp"]: 42 | results[backend] = torch.empty(64) 43 | sketcher = AMSSketcher(min_k=min_k, max_k=max_k, bits=bits, backend=backend) 44 | sketcher.string_to_soft_hash(string, results[backend]) 45 | 46 | expected = results["python"] 47 | actual = results["cpp"] 48 | tol = expected.std().item() * 1e-6 49 | assert (actual - expected).abs().max().item() < tol 50 | 51 | 52 | @pytest.mark.parametrize("k", [2, 3, 4, 5, 8, 16, 32]) 53 | def test_string_to_clock_hash(k): 54 | string = random_string(1000) 55 | 56 | results = {} 57 | for backend in ["python", "cpp"]: 58 | sketcher = ClockSketcher(k, backend=backend) 59 | results[backend] = sketcher.init_sketch() 60 | sketcher.string_to_hash(string, results[backend]) 61 | 62 | expected = results["python"] 63 | actual = results["cpp"] 64 | assert (actual.clocks == expected.clocks).all() 65 | assert (actual.count == expected.count).all() 66 | 67 | 68 | @pytest.mark.parametrize("k", [2, 3, 4, 5, 8, 16, 32]) 69 | @pytest.mark.parametrize("size", [20000]) 70 | def test_clock_cdiff(k, size): 71 | n = 10 72 | strings = [random_string(size) for _ in range(n)] 73 | sketcher = ClockSketcher(k) 74 | sketch = sketcher.init_sketch(n) 75 | for i, string in enumerate(strings): 76 | sketcher.string_to_hash(string, sketch[i]) 77 | 78 | cdiff = sketcher.cdiff(sketch, sketch) 79 | assert cdiff.shape == (n, n) 80 | assert (cdiff.clocks.transpose(0, 1) == -cdiff.clocks).all() 81 | assert (cdiff.clocks.diagonal(dim1=0, dim2=1) == 0).all() 82 | assert (cdiff.count.diagonal(dim1=0, dim2=1) == 0).all() 83 | mask = torch.arange(n) < torch.arange(n).unsqueeze(-1) 84 | mean = cdiff.clocks[mask].abs().float().mean().item() 85 | assert mean > 64 - 10, mean 86 | -------------------------------------------------------------------------------- /test/test_softmax_tree.py: -------------------------------------------------------------------------------- 1 | # Copyright Contributors to the Pyro-Cov project. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import math 5 | 6 | import pyro 7 | import pyro.distributions as dist 8 | import pytest 9 | import torch 10 | from pyro.infer import SVI, Trace_ELBO 11 | from pyro.optim import Adam 12 | 13 | from pyrocov.markov_tree import MarkovTree 14 | from pyrocov.phylo import Phylogeny 15 | from pyrocov.softmax_tree import SoftmaxTree 16 | 17 | 18 | @pytest.mark.parametrize("num_bits", [2, 3, 4, 5, 10]) 19 | @pytest.mark.parametrize("num_leaves", [2, 3, 4, 5, 10]) 20 | def test_rsample(num_leaves, num_bits): 21 | phylo = Phylogeny.generate(num_leaves) 22 | leaf_times = phylo.times[phylo.leaves] 23 | bit_times = torch.randn(num_bits) 24 | logits = torch.randn(num_leaves, num_bits) 25 | tree = SoftmaxTree(leaf_times, bit_times, logits) 26 | value = tree.rsample() 27 | tree.log_prob(value) 28 | 29 | 30 | def model(leaf_times, leaf_states, num_features): 31 | assert len(leaf_times) == len(leaf_states) 32 | 33 | # Timed tree concerning reproductive behavior only. 34 | coal_params = pyro.sample("coal_params", dist.CoalParamPrior("TODO")) # global 35 | # Note this is where our coalescent model assumes geographically 36 | # homogeneous reproductive rate, which is not realistic. 37 | # See appendix of (Vaughan et al. 2014) for discussion of this assumption. 38 | phylogeny = pyro.sample("phylogeny", dist.Coalescent(coal_params, leaf_times)) 39 | 40 | # This is compatible with subsampling features, but not leaves. 41 | subs_params = pyro.sample("subs_params", dist.GTRGamma("TODO")) # global 42 | with pyro.plate("features", num_features, leaf_states.size(-1)): 43 | # This is similar to the phylogeographic likelihood in the pyro-cov repo. 44 | # This is simpler (because it is time-homogeneous) 45 | # but more complex in that it is batched. 46 | # This computes mutation likelihood via dynamic programming. 47 | pyro.sample("leaf_times", MarkovTree(phylogeny, subs_params), obs=leaf_states) 48 | 49 | 50 | def guide(leaf_times, leaf_states, num_features, *, logits_fn=None): 51 | assert len(leaf_times) == len(leaf_states) 52 | 53 | # Sample all continuous latents in a giant correlated auxiliary. 54 | aux = pyro.sample("aux", dist.LowRankMultivariateNormal("TODO")) 55 | # Split it up (TODO switch to EasyGuide). 56 | pyro.sample("coal_params", dist.Delta(aux["TODO"])) # global 57 | pyro.sample("subs_params", dist.Delta(aux["TODO"])) # global 58 | # These are the times of each bit in the embedding vector. 59 | bit_times = pyro.sample( 60 | "bit_times", dist.Delta(aux["TODO"]), infer={"is_auxiliary": True} 61 | ) 62 | 63 | # Learn parameters of the discrete distributions, 64 | # possibly conditioned on continuous latents. 65 | if logits_fn is not None: 66 | # Amortized guide, compatible with subsampling leaves but not features. 67 | logits = logits_fn(leaf_states, leaf_times) # batched over leaves 68 | else: 69 | # Fully local guide, compatible with subsampling features but not leaves. 
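        # (one free logit parameter per leaf, learned directly rather than
        # produced by an amortization network)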
70 | with pyro.plate("leaves", len(leaf_times)): 71 | logits = pyro.param( 72 | "logits", lambda: torch.randn(leaf_times.shape), event_dim=0 73 | ) 74 | assert len(logits) == len(leaf_times) 75 | 76 | pyro.sample("phylogeny", SoftmaxTree(bit_times, logits)) 77 | 78 | 79 | @pytest.mark.xfail(reason="WIP") 80 | @pytest.mark.parametrize("num_features", [4]) 81 | @pytest.mark.parametrize("num_leaves", [2, 3, 4, 5, 10, 100]) 82 | def test_svi(num_leaves, num_features): 83 | phylo = Phylogeny.generate(num_leaves) 84 | leaf_times = phylo.times[phylo.leaves] 85 | leaf_states = torch.full((num_leaves, num_features), 0.5).bernoulli() 86 | 87 | svi = SVI(model, guide, Adam({"lr": 1e-4}), Trace_ELBO()) 88 | for i in range(2): 89 | loss = svi.step(leaf_times, leaf_states) 90 | assert math.isfinite(loss) 91 | -------------------------------------------------------------------------------- /test/test_strains.py: -------------------------------------------------------------------------------- 1 | # Copyright Contributors to the Pyro-Cov project. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import pytest 5 | 6 | from pyrocov.strains import TimeSpaceStrainModel, simulate 7 | 8 | 9 | @pytest.mark.parametrize("T", [32]) 10 | @pytest.mark.parametrize("R", [6]) 11 | @pytest.mark.parametrize("S", [5]) 12 | def test_strains(T, R, S): 13 | dataset = simulate(T, R, S) 14 | model = TimeSpaceStrainModel(**dataset) 15 | model.fit(num_steps=101, haar=False) 16 | model.median() 17 | -------------------------------------------------------------------------------- /test/test_substitution.py: -------------------------------------------------------------------------------- 1 | # Copyright Contributors to the Pyro-Cov project. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import pyro.poutine as poutine 5 | import pytest 6 | import torch 7 | from pyro.infer.autoguide import AutoDelta 8 | 9 | from pyrocov.substitution import GeneralizedTimeReversible, JukesCantor69 10 | 11 | 12 | @pytest.mark.parametrize("Model", [JukesCantor69, GeneralizedTimeReversible]) 13 | def test_matrix_exp(Model): 14 | model = Model() 15 | guide = AutoDelta(model) 16 | guide() 17 | trace = poutine.trace(guide).get_trace() 18 | t = torch.randn(10).exp() 19 | with poutine.replay(trace=trace): 20 | m = model() 21 | assert torch.allclose(model(), m) 22 | 23 | exp_mt = (m * t[:, None, None]).matrix_exp() 24 | actual = model.matrix_exp(t) 25 | assert torch.allclose(actual, exp_mt, atol=1e-6) 26 | 27 | actual = model.log_matrix_exp(t) 28 | log_exp_mt = exp_mt.log() 29 | assert torch.allclose(actual, log_exp_mt, atol=1e-6) 30 | -------------------------------------------------------------------------------- /test/test_usher.py: -------------------------------------------------------------------------------- 1 | # Copyright Contributors to the Pyro-Cov project. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import math 5 | import os 6 | import random 7 | import tempfile 8 | from collections import defaultdict 9 | 10 | from pyrocov.align import PANGOLEARN_DATA 11 | from pyrocov.usher import prune_mutation_tree, refine_mutation_tree 12 | 13 | 14 | def test_refine_prune(): 15 | filename1 = os.path.join(PANGOLEARN_DATA, "lineageTree.pb") 16 | with tempfile.TemporaryDirectory() as tmpdirname: 17 | filename2 = os.path.join(tmpdirname, "refinedTree.pb") 18 | filename3 = os.path.join(tmpdirname, "prunedTree.pb") 19 | 20 | # Refine the tree. 
21 | fine_to_coarse = refine_mutation_tree(filename1, filename2) 22 | 23 | # Find canonical fine names for each coarse name. 24 | coarse_to_fine = defaultdict(list) 25 | for fine, coarse in fine_to_coarse.items(): 26 | coarse_to_fine[coarse].append(fine) 27 | coarse_to_fine = {k: min(vs) for k, vs in coarse_to_fine.items()} 28 | 29 | # Prune the tree, keeping coarse nodes. 30 | weights = {fine: random.lognormvariate(0, 1) for fine in fine_to_coarse} 31 | for fine in coarse_to_fine.values(): 32 | weights[fine] = math.inf 33 | prune_mutation_tree(filename2, filename3, weights=weights, max_num_nodes=10000) 34 | --------------------------------------------------------------------------------