├── .github
│   └── workflows
│       └── ci.yml
├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── Makefile
├── README.md
├── generate_epiToPublicAndDate.ipynb
├── generate_epiToPublicAndDate.py
├── notebooks
│   ├── NS_ratio.ipynb
│   ├── characterize_predicted.ipynb
│   ├── contrast-strains.ipynb
│   ├── dissect_delta_logR.ipynb
│   ├── escape_calc.ipynb
│   ├── explore-gisaid-af.ipynb
│   ├── explore-gisaid.ipynb
│   ├── explore-jhu-time-series.ipynb
│   ├── explore-mutation-annotated-tree.ipynb
│   ├── explore-nextclade-counts.ipynb
│   ├── explore-nextclade.ipynb
│   ├── explore-nextstrain-open-data.ipynb
│   ├── explore-owid-vaccinations.ipynb
│   ├── explore-usher.ipynb
│   ├── explore_pairs_of_lineages.ipynb
│   ├── explore_pango_lineages.ipynb
│   ├── grid_search.ipynb
│   ├── ind_emergences_violin.ipynb
│   ├── make_mutations_subset.ipynb
│   ├── mini_mutrans.ipynb
│   ├── mutrans-vary-coef-scale.ipynb
│   ├── mutrans.ipynb
│   ├── mutrans_backtesting-fullrun.ipynb
│   ├── mutrans_backtesting.ipynb
│   ├── mutrans_forecasts.ipynb
│   ├── mutrans_gene.ipynb
│   ├── mutrans_hack_7_2022.ipynb
│   ├── mutrans_loo.ipynb
│   ├── paper
│   ├── preprocess-pairwise-allele-fractions.ipynb
│   ├── rank_mutations.ipynb
│   ├── results
│   └── slop_mapping.ipynb
├── paper
│   ├── .gitignore
│   ├── BA.1_logistic_regression_all_US.jpg
│   ├── Makefile
│   ├── N_S_peak_ratio.png
│   ├── N_S_ratio.png
│   ├── README.md
│   ├── Science.bst
│   ├── accession_ids.txt.xz
│   ├── backtesting
│   │   ├── L1_error_barplot_England.png
│   │   ├── L1_error_barplot_all.png
│   │   ├── L1_error_barplot_other.png
│   │   ├── L1_error_barplot_top100-1000.png
│   │   ├── L1_error_barplot_top100.png
│   │   ├── backtesting_day_248_early_prediction_england.png
│   │   ├── backtesting_day_346_early_prediction_england.png
│   │   ├── backtesting_day_360_early_prediction_england.png
│   │   ├── backtesting_day_430.png
│   │   ├── backtesting_day_444.png
│   │   ├── backtesting_day_458.png
│   │   ├── backtesting_day_472.png
│   │   ├── backtesting_day_486.png
│   │   ├── backtesting_day_500.png
│   │   ├── backtesting_day_514.png
│   │   ├── backtesting_day_514_early_prediction_england.png
│   │   ├── backtesting_day_528.png
│   │   ├── backtesting_day_528_early_prediction_england.png
│   │   ├── backtesting_day_542.png
│   │   ├── backtesting_day_556.png
│   │   ├── backtesting_day_570.png
│   │   ├── backtesting_day_584.png
│   │   ├── backtesting_day_598.png
│   │   ├── backtesting_day_612.png
│   │   ├── backtesting_day_612_early_prediction_england.png
│   │   ├── backtesting_day_626.png
│   │   ├── backtesting_day_640.png
│   │   ├── backtesting_day_640_early_prediction_england.png
│   │   ├── backtesting_day_654.png
│   │   ├── backtesting_day_668.png
│   │   ├── backtesting_day_682.png
│   │   ├── backtesting_day_696.png
│   │   ├── backtesting_day_710.png
│   │   ├── backtesting_day_710_early_prediction_england.png
│   │   ├── backtesting_day_724.png
│   │   ├── backtesting_day_724_early_prediction_england.png
│   │   ├── backtesting_day_738.png
│   │   ├── backtesting_day_738_early_prediction_england.png
│   │   ├── backtesting_day_752.png
│   │   ├── backtesting_day_752_early_prediction_england.png
│   │   ├── backtesting_day_766.png
│   │   ├── backtesting_day_766_early_prediction_Asia_Europe_Africa.png
│   │   ├── backtesting_day_766_early_prediction_Massachusetts_for_heatmap.png
│   │   ├── backtesting_day_766_early_prediction_Massachusetts_for_heatmap_nolegend.png
│   │   ├── backtesting_day_766_early_prediction_USA_France_England_Brazil_Australia_Russia.png
│   │   ├── backtesting_day_766_early_prediction_brazil_for_heatmap.png
│   │   ├── backtesting_day_766_early_prediction_brazil_for_heatmap_nolegend.png
│   │   ├── backtesting_day_766_early_prediction_england.png
│   │   ├── backtesting_day_766_early_prediction_england_for_heatmap.png
│   │   └── backtesting_day_766_early_prediction_england_for_heatmap_nolegend.png
│   ├── binding_retained.png
│   ├── clade_distribution.png
│   ├── coef_scale_manhattan.png
│   ├── coef_scale_pearson.png
│   ├── coef_scale_scatter.png
│   ├── coef_scale_volcano.png
│   ├── convergence.png
│   ├── convergence_loss.png
│   ├── convergent_evolution.jpg
│   ├── deep_scanning.png
│   ├── delta_logR_breakdown.png
│   ├── dobanno_2021_immunogenicity_N.png
│   ├── forecast.png
│   ├── formatted
│   │   ├── Figure1.ai
│   │   ├── Figure1.jpg
│   │   ├── Figure2.ai
│   │   ├── Figure2.jpg
│   │   ├── Figure2_alt.ai
│   │   └── Figure2_alt.jpg
│   ├── gene_ratios.png
│   ├── ind_emergences_violin.png
│   ├── lineage_agreement.png
│   ├── lineage_heterogeneity.png
│   ├── lineage_prediction.png
│   ├── main.bib
│   ├── main.md
│   ├── manhattan.png
│   ├── manhattan_N.png
│   ├── manhattan_N_coef_scale_0.05.png
│   ├── manhattan_ORF1a.png
│   ├── manhattan_ORF1a_coef_scale_0.05.png
│   ├── manhattan_ORF1b.png
│   ├── manhattan_ORF1b_coef_scale_0.05.png
│   ├── manhattan_ORF3a.png
│   ├── manhattan_ORF3a_coef_scale_0.05.png
│   ├── manhattan_S.png
│   ├── manhattan_S_coef_scale_0.05.png
│   ├── moran.csv
│   ├── multinomial_LR_vs_pyr0.jpg
│   ├── mutation_agreement.png
│   ├── mutation_europe_boxplot_rankby_s.png
│   ├── mutation_europe_boxplot_rankby_t.png
│   ├── mutation_scoring
│   │   ├── by_gene_heatmap.png
│   │   ├── plot_individual_mutation_significance.png
│   │   ├── plot_individual_mutation_significance_M.png
│   │   ├── plot_individual_mutation_significance_S.png
│   │   ├── plot_individual_mutation_significance_S__K_to_N.png
│   │   ├── plot_k_to_n_mutation_significance.png
│   │   ├── plot_per_gene_aggregate_sign.png
│   │   ├── plot_pval_vs_top_genes_E.png
│   │   ├── plot_pval_vs_top_genes_N.png
│   │   ├── plot_pval_vs_top_genes_ORF10.png
│   │   ├── plot_pval_vs_top_genes_ORF1a.png
│   │   ├── plot_pval_vs_top_genes_ORF1b.png
│   │   ├── plot_pval_vs_top_genes_ORF3a.png
│   │   ├── plot_pval_vs_top_genes_ORF6.png
│   │   ├── plot_pval_vs_top_genes_ORF7b.png
│   │   ├── plot_pval_vs_top_genes_ORF8.png
│   │   ├── plot_pval_vs_top_genes_ORF9b.png
│   │   ├── plot_pval_vs_top_genes_S.png
│   │   ├── pvals_vs_top_genes.png
│   │   └── top_substitutions.png
│   ├── mutation_summaries.jpg
│   ├── mutations.tsv
│   ├── region_distribution.png
│   ├── relative_growth_rate.jpg
│   ├── schematic_overview.key
│   ├── schematic_overview.png
│   ├── scibib.bib
│   ├── scicite.sty
│   ├── scifile.tex
│   ├── spectrum_transmissibility.jpg
│   ├── strain_emergence.png
│   ├── strain_europe_boxplot.png
│   ├── strain_prevalence.png
│   ├── strain_prevalence_data.csv
│   ├── strains.tsv
│   ├── supplement.tex
│   ├── supplement_table_3_02_05_2022.jpeg
│   ├── table2.jpeg
│   ├── table_1.jpeg
│   ├── table_1.tsv
│   ├── table_1_02_05_2022.jpeg
│   ├── table_2_02_05_2022.jpeg
│   ├── vary_gene_elbo.png
│   ├── vary_gene_likelihood.png
│   ├── vary_gene_loss.png
│   ├── vary_nsp_elbo.png
│   ├── vary_nsp_likelihood.png
│   ├── volcano.csv
│   └── volcano.png
├── pyrocov
│   ├── __init__.py
│   ├── aa.py
│   ├── align.py
│   ├── distributions.py
│   ├── external
│   │   ├── __init__.py
│   │   └── usher
│   │       ├── LICENSE
│   │       ├── README.md
│   │       ├── __init__.py
│   │       └── parsimony_pb2.py
│   ├── fasta.py
│   ├── geo.py
│   ├── growth.py
│   ├── hashsubset.py
│   ├── io.py
│   ├── markov_tree.py
│   ├── mutrans.py
│   ├── mutrans_helpers.py
│   ├── ops.py
│   ├── pangolin.py
│   ├── phylo.py
│   ├── plot_additional_figs.R
│   ├── plotting.py
│   ├── sarscov2.py
│   ├── sketch.cpp
│   ├── sketch.py
│   ├── softmax_tree.py
│   ├── special.py
│   ├── stats.py
│   ├── strains.py
│   ├── substitution.py
│   ├── usher.py
│   └── util.py
├── scripts
│   ├── analyze_nextstrain.py
│   ├── fix_columns.py
│   ├── git_pull.py
│   ├── moran.py
│   ├── mutrans.py
│   ├── preprocess_credits.py
│   ├── preprocess_gisaid.py
│   ├── preprocess_nextstrain.py
│   ├── preprocess_usher.py
│   ├── pull_gisaid.sh
│   ├── pull_nextstrain.sh
│   ├── pull_usher.sh
│   ├── rank_mutations.py
│   ├── run_backtesting.sh
│   └── update_headers.py
├── setup.cfg
├── setup.py
└── test
    ├── __init__.py
    ├── conftest.py
    ├── test_distributions.py
    ├── test_io.py
    ├── test_markov_tree.py
    ├── test_ops.py
    ├── test_phylo.py
    ├── test_sarscov2.py
    ├── test_sketch.py
    ├── test_softmax_tree.py
    ├── test_strains.py
    ├── test_substitution.py
    └── test_usher.py
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 |   push:
5 |     branches: [master]
6 |   pull_request:
7 |     branches: [master]
8 |
9 |   # Allows you to run this workflow manually from the Actions tab
10 |   workflow_dispatch:
11 |
12 | env:
13 |   CXX: g++-8
14 |   CC: gcc-8
15 |
16 | jobs:
17 |   lint:
18 |     runs-on: ubuntu-20.04
19 |     strategy:
20 |       matrix:
21 |         python-version: [3.7]
22 |     steps:
23 |       - uses: actions/checkout@v2
24 |       - name: Set up Python ${{ matrix.python-version }}
25 |         uses: actions/setup-python@v2
26 |         with:
27 |           python-version: ${{ matrix.python-version }}
28 |       - name: Install dependencies
29 |         run: |
30 |           python -m pip install --upgrade pip wheel setuptools
31 |           pip install flake8 black 'isort>=5.0' mypy types-protobuf
32 |       - name: Lint
33 |         run: |
34 |           make lint
35 |   unit:
36 |     runs-on: ubuntu-20.04
37 |     needs: lint
38 |     strategy:
39 |       matrix:
40 |         python-version: [3.7]
41 |     steps:
42 |       - uses: actions/checkout@v2
43 |       - name: Set up Python ${{ matrix.python-version }}
44 |         uses: actions/setup-python@v2
45 |         with:
46 |           python-version: ${{ matrix.python-version }}
47 |       - name: Install dependencies
48 |         run: |
49 |           sudo add-apt-repository -y ppa:ubuntu-toolchain-r/test
50 |           sudo apt-get update
51 |           sudo apt-get install gcc-8 g++-8 ninja-build
52 |           python -m pip install --upgrade pip wheel setuptools
53 |           pip install torch==1.11.0+cpu torchvision==0.12.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
54 |           pip install .[test]
55 |           pip install coveralls
56 |           pip freeze
57 |       - name: Run unit tests
58 |         run: |
59 |           make test
60 |
--------------------------------------------------------------------------------
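The unit job above pins CPU-only PyTorch wheels (`torch==1.11.0+cpu`). A minimal sketch, assuming nothing beyond the workflow above, for confirming the intended build was installed in CI or locally:

```python
import torch

# Expect a "+cpu" build in CI, and no CUDA devices available.
print(torch.__version__, torch.cuda.is_available())
```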
/.gitignore:
--------------------------------------------------------------------------------
1 | build/
2 | nextalign
3 | nextclade
4 | nextclade_results/
5 | /data
6 | /results
7 | /results.*
8 | *.swp
9 | temp.*
10 |
11 | # Byte-compiled / optimized / DLL files
12 | __pycache__/
13 | *.py[cod]
14 | *$py.class
15 |
16 | # C extensions
17 | *.so
18 |
19 | # Distribution / packaging
20 | .Python
21 | build/
22 | develop-eggs/
23 | dist/
24 | downloads/
25 | eggs/
26 | .eggs/
27 | lib/
28 | lib64/
29 | parts/
30 | sdist/
31 | var/
32 | wheels/
33 | pip-wheel-metadata/
34 | share/python-wheels/
35 | *.egg-info/
36 | .installed.cfg
37 | *.egg
38 | MANIFEST
39 |
40 | # PyInstaller
41 | # Usually these files are written by a python script from a template
42 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
43 | *.manifest
44 | *.spec
45 |
46 | # Installer logs
47 | pip-log.txt
48 | pip-delete-this-directory.txt
49 |
50 | # Unit test / coverage reports
51 | htmlcov/
52 | .tox/
53 | .nox/
54 | .coverage
55 | .coverage.*
56 | .cache
57 | nosetests.xml
58 | coverage.xml
59 | *.cover
60 | *.py,cover
61 | .hypothesis/
62 | .pytest_cache/
63 |
64 | # Translations
65 | *.mo
66 | *.pot
67 |
68 | # Django stuff:
69 | *.log
70 | local_settings.py
71 | db.sqlite3
72 | db.sqlite3-journal
73 |
74 | # Flask stuff:
75 | instance/
76 | .webassets-cache
77 |
78 | # Scrapy stuff:
79 | .scrapy
80 |
81 | # Sphinx documentation
82 | docs/_build/
83 |
84 | # PyBuilder
85 | target/
86 |
87 | # Jupyter Notebook
88 | .ipynb_checkpoints
89 |
90 | # IPython
91 | profile_default/
92 | ipython_config.py
93 |
94 | # pyenv
95 | .python-version
96 |
97 | # pipenv
98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
101 | # install all needed dependencies.
102 | #Pipfile.lock
103 |
104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
105 | __pypackages__/
106 |
107 | # Celery stuff
108 | celerybeat-schedule
109 | celerybeat.pid
110 |
111 | # SageMath parsed files
112 | *.sage.py
113 |
114 | # Environments
115 | .env
116 | .venv
117 | env/
118 | venv/
119 | ENV/
120 | env.bak/
121 | venv.bak/
122 |
123 | # Spyder project settings
124 | .spyderproject
125 | .spyproject
126 |
127 | # Rope project settings
128 | .ropeproject
129 |
130 | # mkdocs documentation
131 | /site
132 |
133 | # mypy
134 | .mypy_cache/
135 | .dmypy.json
136 | dmypy.json
137 |
138 | # Pyre type checker
139 | .pyre/
140 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | ## File organization
2 |
3 | - The root directory `/` contains scripts and jupyter notebooks.
4 | - All module code lives in the `/pyrocov/` package directory.
5 |   As a rule of thumb, if a .py file is imported by any other file,
6 |   it should live in `/pyrocov/`; if a .py file has an `if __name__ == "__main__"` check, it should live in the root directory (see the example below).
7 | - Notebooks should be cleared of data before committing.
8 | - The `/results/` directory (not under git control) contains large intermediate data,
9 |   including preprocessed input data and outputs from models and predictions.
10 | - The `/paper/` directory contains git-controlled outputs for the paper, namely
11 |   plots, .tsv files, and .fasta files for sharing outside the dev team.
12 |
13 | ## Committing code
14 |
15 | - The current policy is to allow pushing of code directly, without PRs.
16 | - If you would like code reviewed, please submit a PR and tag a reviewer.
17 | - Please run `make format` and `make lint` before committing code.
18 | - We'd like to resurrect `make test`, but it is currently broken.
19 | - Each notebook has an owner (see git history);
20 |   changes to notebooks by non-owners should go through a pull request.
21 |   This rule is intended to reduce difficult merge conflicts in Jupyter notebooks.
22 |
23 |
--------------------------------------------------------------------------------
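The layout rule in CONTRIBUTING.md above, illustrated with a minimal, hypothetical example (file names invented here for illustration; they are not part of the repo):

```python
# pyrocov/fancy_helper.py -- importable library code, so it lives in /pyrocov/
def double(xs):
    """Return a new list with every element doubled."""
    return [2 * x for x in xs]
```

```python
# fancy_script.py -- has a __main__ check, so it lives in the repo root
from pyrocov.fancy_helper import double

if __name__ == "__main__":
    print(double([1, 2, 3]))  # [2, 4, 6]
```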
/Makefile:
--------------------------------------------------------------------------------
1 | SHELL := /bin/bash
2 |
3 | ###########################################################################
4 | # installation
5 |
6 | install: FORCE
7 | 	pip install -e .[test]
8 |
9 | install-pangolin:
10 | 	conda install -y -c bioconda -c conda-forge -c defaults pangolin
11 |
12 | install-usher:
13 | 	conda install -y -c bioconda -c conda-forge -c defaults usher
14 |
15 | ###########################################################################
16 | # ci tasks
17 |
18 | lint: FORCE
19 | 	flake8 --extend-exclude=pyrocov/external
20 | 	black --extend-exclude='notebooks|pyrocov/external' --check .
21 | 	isort --check --skip=pyrocov/external .
22 | 	python scripts/update_headers.py --check
23 | 	mypy . --exclude=build/
24 |
25 | format: FORCE
26 | 	black --extend-exclude='notebooks|pyrocov/external' .
27 | 	isort --skip=pyrocov/external .
28 | 	python scripts/update_headers.py
29 |
30 | test: lint FORCE
31 | 	python scripts/git_pull.py --no-update cov-lineages/pango-designation
32 | 	python scripts/git_pull.py --no-update cov-lineages/pangoLEARN
33 | 	python scripts/git_pull.py --no-update nextstrain/nextclade
34 | 	pytest -v -n auto test
35 | 	test -e results/columns.3000.pkl \
36 | 		&& python scripts/mutrans.py --test -n 2 -s 2 \
37 | 		|| echo skipping test
38 |
39 | ###########################################################################
40 | # Main processing workflows
41 |
42 | # The DO_NOT_UPDATE logic aims to prevent someone from accidentally updating
43 | # a frozen results directory.
44 | update: FORCE
45 | 	! test -f results/DO_NOT_UPDATE
46 | 	scripts/pull_nextstrain.sh
47 | 	scripts/pull_usher.sh
48 | 	python scripts/git_pull.py cov-lineages/pango-designation
49 | 	python scripts/git_pull.py cov-lineages/pangoLEARN
50 | 	python scripts/git_pull.py CSSEGISandData/COVID-19
51 | 	python scripts/git_pull.py nextstrain/nextclade
52 | 	echo "frozen" > results/DO_NOT_UPDATE
53 |
54 | preprocess: FORCE
55 | 	python scripts/preprocess_usher.py
56 |
57 | preprocess-gisaid: FORCE
58 | 	python scripts/preprocess_usher.py \
59 | 		--tree-file-in results/gisaid/gisaidAndPublic.masked.pb.gz \
60 | 		--gisaid-metadata-file-in results/gisaid/metadata_2022_*_*.tsv.gz
61 |
62 | analyze: FORCE
63 | 	python scripts/mutrans.py --vary-holdout
64 | 	python scripts/mutrans.py --vary-gene
65 | 	python scripts/mutrans.py --vary-nsp
66 | 	python scripts/mutrans.py --vary-leaves=9999 --num-steps=2001
67 |
68 | backtesting-piecewise: FORCE
69 | # Generates all the backtesting models piece by piece so that they can be run on a GPU-enabled machine
70 | 	python scripts/mutrans.py --backtesting-max-day `seq -s, 150 14 220` --forecast-steps 12
71 | 	python scripts/mutrans.py --backtesting-max-day `seq -s, 220 14 500` --forecast-steps 12
72 | 	python scripts/mutrans.py --backtesting-max-day `seq -s, 500 14 625` --forecast-steps 12
73 | 	python scripts/mutrans.py --backtesting-max-day `seq -s, 626 14 700` --forecast-steps 12
74 | 	python scripts/mutrans.py --backtesting-max-day `seq -s, 710 14 766` --forecast-steps 12
75 |
76 | backtesting-nofeatures: FORCE
77 | # Same piecewise backtesting runs, using the reparam-localinit-nofeatures model variant
78 | 	python scripts/mutrans.py --backtesting-max-day `seq -s, 150 14 220` --forecast-steps 12 --model-type reparam-localinit-nofeatures
79 | 	python scripts/mutrans.py --backtesting-max-day `seq -s, 220 14 500` --forecast-steps 12 --model-type reparam-localinit-nofeatures
80 | 	python scripts/mutrans.py --backtesting-max-day `seq -s, 500 14 625` --forecast-steps 12 --model-type reparam-localinit-nofeatures
81 | 	python scripts/mutrans.py --backtesting-max-day `seq -s, 626 14 700` --forecast-steps 12 --model-type reparam-localinit-nofeatures
82 | 	python scripts/mutrans.py --backtesting-max-day `seq -s, 710 14 766` --forecast-steps 12 --model-type reparam-localinit-nofeatures
83 |
84 | backtesting-complete: FORCE
85 | # Run only after backtesting-piecewise, on a machine with >500 GB of RAM, to aggregate results
86 | 	python scripts/mutrans.py --backtesting-max-day `seq -s, 150 14 766` --forecast-steps 12
87 |
88 | backtesting: FORCE
89 | # Largest run that fits on a high-memory GPU machine
90 | 	python scripts/mutrans.py --backtesting-max-day `seq -s, 430 14 766` --forecast-steps 12
91 |
92 | backtesting-short: FORCE
93 | # For quick testing of backtesting code changes
94 | 	python scripts/mutrans.py --backtesting-max-day `seq -s, 500 14 700` --forecast-steps 12
95 |
96 | EXCLUDE='.*\.json$$|.*mutrans\.pt$$|.*temp\..*|.*\.[EI](gene|region)=.*\.pt$$|.*__(gene|region|lineage)__.*\.pt$$'
97 |
98 | push: FORCE
99 | 	gsutil -m -o GSUtil:parallel_composite_upload_threshold=150M \
100 | 		rsync -r -x $(EXCLUDE) \
101 | 		$(shell readlink results)/ \
102 | 		gs://pyro-cov/$(shell readlink results | grep -o 'results\.[-0-9]*')/
103 |
104 | pull: FORCE
105 | 	gsutil -m rsync -r -x $(EXCLUDE) \
106 | 		gs://pyro-cov/$(shell readlink results | grep -o 'results\.[-0-9]*')/ \
107 | 		$(shell readlink results)/
108 |
109 | ###########################################################################
110 | # TODO remove these user-specific targets
111 |
112 | data:
113 | 	ln -sf ~/Google\ Drive\ File\ Stream/Shared\ drives/Pyro\ CoV data
114 |
115 | ssh-cpu:
116 | 	gcloud compute ssh --project pyro-284215 --zone us-central1-c pyro-cov-fritzo-vm -- -AX -t 'cd pyro-cov ; bash --login'
117 |
118 | ssh-gpu:
119 | 	gcloud compute ssh --project pyro-284215 --zone us-central1-c pyro-fritzo-vm-gpu -- -AX -t 'cd pyro-cov ; bash --login'
120 |
121 | FORCE:
122 |
--------------------------------------------------------------------------------
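The backtesting targets above pass `--backtesting-max-day` a comma-separated list of days produced by `seq -s, START STEP END`. A minimal sketch of the same grid in Python, useful for checking which days a run will cover (it assumes nothing beyond the Makefile above):

```python
# Equivalent of `seq -s, 150 14 766` from the backtesting-complete target:
# days 150, 164, ..., 766 (a 14-day stride), joined with commas.
days = ",".join(str(day) for day in range(150, 766 + 1, 14))
print(days)  # "150,164,178,...,766"
```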
/README.md:
--------------------------------------------------------------------------------
1 | [Latest release](https://github.com/broadinstitute/pyro-cov/releases)
2 | [DOI](https://zenodo.org/badge/latestdoi/292037402)
3 |
4 | # Pyro models for SARS-CoV-2 analysis
5 |
6 | 
7 |
8 | Supporting material for the paper ["Analysis of 6.4 million SARS-CoV-2 genomes identifies mutations associated with fitness"](https://www.science.org/doi/10.1126/science.abm1208) ([medRxiv](https://www.medrxiv.org/content/10.1101/2021.09.07.21263228v2)). Figures and supplementary data for that paper are in the [paper/](paper/) directory.
9 |
10 | This is open source, but we do not intend to support the code for use by outside groups. To use the outputs of this model, we recommend ingesting the tables [strains.tsv](paper/strains.tsv) and [mutations.tsv](paper/mutations.tsv).
11 |
12 | ## Reproducing
13 |
14 | ### Install software
15 |
16 | Clone this repository:
17 | ```sh
18 | git clone git@github.com:broadinstitute/pyro-cov
19 | cd pyro-cov
20 | ```
21 |
22 | Install this Python package:
23 | ```sh
24 | pip install -e .
25 | ```
26 |
27 | ### Get access to GISAID data
28 |
29 | Work with GISAID to get a data agreement.
30 | Define the following environment variables:
31 | ```
32 | GISAID_USERNAME
33 | GISAID_PASSWORD
34 | GISAID_FEED
35 | ```
36 | For example, my username is `fritz` and my GISAID feed is `broad2`.
37 |
38 | ### Download data
39 | This downloads data from GISAID and clones repos for other data sources.
40 | ```sh
41 | make update
42 | ```
43 |
44 | ### Preprocess data
45 |
46 | This takes under an hour.
47 | Results are cached in the `results/` directory, so re-running on newly pulled data can reuse earlier alignment and PANGO lineage classification work.
48 | ```sh
49 | make preprocess
50 | ```
51 |
52 | ### Analyze data
53 | ```sh
54 | make analyze
55 | ```
56 |
57 | ### Generate plots and tables
58 | Plots and tables are generated by running various notebooks:
59 | - [mutrans.ipynb](notebooks/mutrans.ipynb)
60 | - [mutrans_backtesting.ipynb](notebooks/mutrans_backtesting.ipynb)
61 | - [mutrans_gene.ipynb](notebooks/mutrans_gene.ipynb)
62 |
63 | ## Citing
64 |
65 | If you use this software or the predictions in the [paper](paper/) directory, please consider citing:
66 |
67 | ```
68 | @article {Obermeyer2021.09.07.21263228,
69 | author = {Obermeyer, Fritz and
70 | Schaffner, Stephen F. and
71 | Jankowiak, Martin and
72 | Barkas, Nikolaos and
73 | Pyle, Jesse D. and
74 | Park, Daniel J. and
75 | MacInnis, Bronwyn L. and
76 | Luban, Jeremy and
77 | Sabeti, Pardis C. and
78 | Lemieux, Jacob E.},
79 | title = {Analysis of 2.1 million SARS-CoV-2 genomes identifies mutations associated with transmissibility},
80 | elocation-id = {2021.09.07.21263228},
81 | year = {2021},
82 | doi = {10.1101/2021.09.07.21263228},
83 | publisher = {Cold Spring Harbor Laboratory Press},
84 | URL = {https://www.medrxiv.org/content/early/2021/09/13/2021.09.07.21263228},
85 | eprint = {https://www.medrxiv.org/content/early/2021/09/13/2021.09.07.21263228.full.pdf},
86 | journal = {medRxiv}
87 | }
88 | ```
89 |
--------------------------------------------------------------------------------
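As the README notes, downstream users are expected to ingest `paper/strains.tsv` and `paper/mutations.tsv`. A minimal sketch of doing so with pandas (column names as used by the notebooks in this repo):

```python
import pandas as pd

# Model outputs shared in the paper/ directory.
strains = pd.read_csv("paper/strains.tsv", sep="\t")
mutations = pd.read_csv("paper/mutations.tsv", sep="\t")

# Rank mutations by the magnitude of their estimated growth-rate effect.
order = mutations["Δ log R"].abs().sort_values(ascending=False).index
print(mutations.loc[order, ["mutation", "Δ log R"]].head(10))
```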
/generate_epiToPublicAndDate.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "2f544c93-83ef-455f-b3d0-b507d6440376",
6 | "metadata": {},
7 | "source": [
8 | "Preprocess files for parsing"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 4,
14 | "id": "7dc48887-b123-43bb-9994-321a4c47abee",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import pandas as pd"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 6,
24 | "id": "5352238b-51d6-456f-99c7-1e125e2099b2",
25 | "metadata": {},
26 | "outputs": [
27 | {
28 | "name": "stderr",
29 | "output_type": "stream",
30 | "text": [
31 | "/tmp/ipykernel_2032/3544606020.py:1: DtypeWarning: Columns (16) have mixed types. Specify dtype option on import or set low_memory=False.\n",
32 | " gisaid_meta = pd.read_csv(\"results/gisaid/metadata_2022_07_27.tsv.gz\", sep=\"\\t\")\n"
33 | ]
34 | }
35 | ],
36 | "source": [
37 | "gisaid_meta = pd.read_csv(\"results/gisaid/metadata_2022_07_27.tsv.gz\", sep=\"\\t\")"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 7,
43 | "id": "9b3447d7-13d6-4f76-a922-8f564259bd85",
44 | "metadata": {},
45 | "outputs": [],
46 | "source": [
47 | "gisaid_meta[\"vname\"] = gisaid_meta[\"Virus name\"].str.replace(\"hCoV-19/\",\"\")\n",
48 | "gisaid_meta[\"vname2\"] = gisaid_meta[\"vname\"]"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": 8,
54 | "id": "2b551436-24ed-4949-b0d2-ac863e64a5f0",
55 | "metadata": {},
56 | "outputs": [],
57 | "source": [
58 | "epi_map = gisaid_meta[[\"Accession ID\", \"vname\", \"vname2\", \"Collection date\"]]"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": null,
64 | "id": "a20539c6-4843-4107-8e0c-f7e621103659",
65 | "metadata": {},
66 | "outputs": [],
67 | "source": [
68 | "epi_map = epi_map.sort_values(by=\"Accession ID\", ascending = True)"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": null,
74 | "id": "66119639-6cad-4bd0-aa41-7d0d161c02f7",
75 | "metadata": {},
76 | "outputs": [],
77 | "source": [
78 | "epi_map"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": null,
84 | "id": "c8564631-e7bc-4907-8790-1935b48df902",
85 | "metadata": {},
86 | "outputs": [],
87 | "source": [
88 | "epi_map.to_csv(\"results/gisaid/epiToPublicAndDate.latest\", header=False, sep=\"\\t\", index=False)"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": null,
94 | "id": "017c3a6d-5fbc-444d-80cd-5fd8ccc8324d",
95 | "metadata": {},
96 | "outputs": [],
97 | "source": [
98 | "?pd.to_csv"
99 | ]
100 | },
101 | {
102 | "cell_type": "code",
103 | "execution_count": null,
104 | "id": "5b63f7f3-3742-4145-a873-d21ee6ec6f70",
105 | "metadata": {},
106 | "outputs": [],
107 | "source": [
108 | "?pd.DataFrame.to_csv"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": null,
114 | "id": "6bc024e5-1f2c-4598-9943-6d6a77be2bd8",
115 | "metadata": {},
116 | "outputs": [],
117 | "source": []
118 | }
119 | ],
120 | "metadata": {
121 | "kernelspec": {
122 | "display_name": "Python 3 (ipykernel)",
123 | "language": "python",
124 | "name": "python3"
125 | },
126 | "language_info": {
127 | "codemirror_mode": {
128 | "name": "ipython",
129 | "version": 3
130 | },
131 | "file_extension": ".py",
132 | "mimetype": "text/x-python",
133 | "name": "python",
134 | "nbconvert_exporter": "python",
135 | "pygments_lexer": "ipython3",
136 | "version": "3.9.12"
137 | }
138 | },
139 | "nbformat": 4,
140 | "nbformat_minor": 5
141 | }
142 |
--------------------------------------------------------------------------------
/generate_epiToPublicAndDate.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | #!/usr/bin/env python
5 | # coding: utf-8
6 |
7 | # Preprocess files for parsing
8 |
9 | # In[1]:
10 |
11 |
12 | import pandas as pd
13 |
14 | # In[2]:
15 |
16 |
17 | gisaid_meta = pd.read_csv("results/gisaid/metadata_2022_08_08.tsv.gz", sep="\t")
18 |
19 |
20 | # In[3]:
21 |
22 |
23 | gisaid_meta["vname"] = gisaid_meta["Virus name"].str.replace("hCoV-19/", "")
24 | gisaid_meta["vname2"] = gisaid_meta["vname"]
25 |
26 |
27 | # In[4]:
28 |
29 |
30 | epi_map = gisaid_meta[["Accession ID", "vname", "vname2", "Collection date"]]
31 |
32 |
33 | # In[5]:
34 |
35 |
36 | epi_map = epi_map.sort_values(by="Accession ID", ascending=True)
37 |
38 |
39 | # In[6]:
40 |
41 |
42 | epi_map
43 |
44 |
45 | # In[7]:
46 |
47 |
48 | epi_map.to_csv(
49 | "results/gisaid/epiToPublicAndDate.latest", header=False, sep="\t", index=False
50 | )
51 |
52 |
53 | # In[8]:
54 |
55 |
56 | # get_ipython().run_line_magic('pinfo', 'pd.to_csv')
57 |
58 |
59 |
60 | # In[ ]:
61 |
--------------------------------------------------------------------------------
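Both the notebook and script above hard-code a single dated GISAID metadata dump. A small sketch, assuming only the `results/gisaid/metadata_YYYY_MM_DD.tsv.gz` naming pattern used above, that selects the newest dump instead:

```python
import glob

# Dated names sort lexicographically, so the last match is the newest dump.
candidates = sorted(glob.glob("results/gisaid/metadata_*.tsv.gz"))
latest = candidates[-1]
print("reading", latest)
```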
/notebooks/contrast-strains.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## What are the salient differences between strains?\n",
8 | "\n",
9 | "This notebook addresses a [question of Tom Wenseleers](https://twitter.com/TWenseleers/status/1438780125479329792) about the salient differences between two strains, say between B.1.617.1 and the similar B.1.617.2 and B.1.617.3. You should be able to run this notebook merely after git cloning, to explore other salient differences."
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "First load the precomputed data."
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 1,
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "import pandas as pd\n",
26 | "strains_df = pd.read_csv(\"paper/strains.tsv\", sep=\"\\t\")\n",
27 | "mutations_df = pd.read_csv(\"paper/mutations.tsv\", sep=\"\\t\")"
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "metadata": {},
33 | "source": [
34 | "Convert to dictionaries for easier use."
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": 2,
40 | "metadata": {},
41 | "outputs": [],
42 | "source": [
43 | "mutations_per_strain = {\n",
44 | " strain: frozenset(mutations.split(\",\"))\n",
45 | " for strain, mutations in zip(strains_df[\"strain\"], strains_df[\"mutations\"])\n",
46 | "}"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": 3,
52 | "metadata": {},
53 | "outputs": [],
54 | "source": [
55 | "effect_size = dict(zip(mutations_df[\"mutation\"], mutations_df[\"Δ log R\"]))"
56 | ]
57 | },
58 | {
59 | "cell_type": "markdown",
60 | "metadata": {},
61 | "source": [
62 | "Create a helper to explore pairwise differences."
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": 4,
68 | "metadata": {},
69 | "outputs": [],
70 | "source": [
71 | "def print_diff(strain1, strain2, max_results=10):\n",
72 | " mutations1 = mutations_per_strain[strain1]\n",
73 | " mutations2 = mutations_per_strain[strain2]\n",
74 | " diff = [(m, effect_size[m]) for m in mutations1 ^ mutations2]\n",
75 | " diff.sort(key=lambda me: -abs(me[1]))\n",
76 | " print(f\"{strain1} versus {strain2}\")\n",
77 | " print(\"AA mutation Δ log R Present in strain\")\n",
78 | " print(\"--------------------------------------------\")\n",
79 | " for m, e in diff[:max_results]:\n",
80 | " strain = strain1 if m in mutations1 else strain2\n",
81 | " print(f\"{m: <15s} {e: <10.3g} {strain}\")"
82 | ]
83 | },
84 | {
85 | "cell_type": "markdown",
86 | "metadata": {},
87 | "source": [
88 | "Examine some example sequences."
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": 5,
94 | "metadata": {},
95 | "outputs": [
96 | {
97 | "name": "stdout",
98 | "output_type": "stream",
99 | "text": [
100 | "B.1.617.2 versus B.1.617.1\n",
101 | "AA mutation Δ log R Present in strain\n",
102 | "--------------------------------------------\n",
103 | "ORF8:L84S 0.235 B.1.617.1\n",
104 | "ORF1a:S318L 0.206 B.1.617.1\n",
105 | "ORF1a:G392D 0.111 B.1.617.1\n",
106 | "ORF1b:P314F 0.11 B.1.617.1\n",
107 | "N:A220V 0.0752 B.1.617.2\n",
108 | "S:E484Q 0.0703 B.1.617.1\n",
109 | "S:A222V 0.0677 B.1.617.2\n",
110 | "ORF1a:M585V -0.0624 B.1.617.2\n",
111 | "ORF1a:S2535L 0.0612 B.1.617.1\n",
112 | "N:S194L 0.0599 B.1.617.1\n"
113 | ]
114 | }
115 | ],
116 | "source": [
117 | "print_diff(\"B.1.617.2\", \"B.1.617.1\")"
118 | ]
119 | },
120 | {
121 | "cell_type": "code",
122 | "execution_count": 6,
123 | "metadata": {},
124 | "outputs": [
125 | {
126 | "name": "stdout",
127 | "output_type": "stream",
128 | "text": [
129 | "B.1.617.2 versus B.1.617.3\n",
130 | "AA mutation Δ log R Present in strain\n",
131 | "--------------------------------------------\n",
132 | "ORF1a:K680N 0.415 B.1.617.3\n",
133 | "N:A220V 0.0752 B.1.617.2\n",
134 | "S:E484Q 0.0703 B.1.617.3\n",
135 | "S:A222V 0.0677 B.1.617.2\n",
136 | "ORF1a:M585V -0.0624 B.1.617.2\n",
137 | "N:S187L 0.0596 B.1.617.3\n",
138 | "ORF1a:D2980N 0.0498 B.1.617.2\n",
139 | "ORF1a:M3655I 0.0413 B.1.617.3\n",
140 | "ORF1a:P309L 0.0409 B.1.617.2\n",
141 | "ORF1a:S3675- 0.0409 B.1.617.3\n"
142 | ]
143 | }
144 | ],
145 | "source": [
146 | "print_diff(\"B.1.617.2\", \"B.1.617.3\")"
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": null,
152 | "metadata": {},
153 | "outputs": [],
154 | "source": []
155 | }
156 | ],
157 | "metadata": {
158 | "kernelspec": {
159 | "display_name": "Python 3",
160 | "language": "python",
161 | "name": "python3"
162 | },
163 | "language_info": {
164 | "codemirror_mode": {
165 | "name": "ipython",
166 | "version": 3
167 | },
168 | "file_extension": ".py",
169 | "mimetype": "text/x-python",
170 | "name": "python",
171 | "nbconvert_exporter": "python",
172 | "pygments_lexer": "ipython3",
173 | "version": "3.7.0"
174 | }
175 | },
176 | "nbformat": 4,
177 | "nbformat_minor": 4
178 | }
179 |
--------------------------------------------------------------------------------
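A variation on the notebook's `print_diff` helper: the same ranking applied to mutations *shared* by two strains, using set intersection (`&`) instead of symmetric difference (`^`), and reusing the `mutations_per_strain` and `effect_size` dictionaries defined above:

```python
def print_shared(strain1, strain2, max_results=10):
    """Print mutations common to both strains, ranked by |Δ log R|."""
    shared = mutations_per_strain[strain1] & mutations_per_strain[strain2]
    ranked = sorted(shared, key=lambda m: -abs(effect_size[m]))
    for m in ranked[:max_results]:
        print(f"{m: <15s} {effect_size[m]: <10.3g}")

# e.g. print_shared("B.1.617.2", "B.1.617.3")
```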
/notebooks/explore-nextclade.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import json\n",
10 | "import pandas as pd"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": null,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "with open(\"results/gisaid.subset.json\", \"r\") as f:\n",
20 | " data = json.load(f)"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": null,
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "data[0].keys()"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": null,
35 | "metadata": {},
36 | "outputs": [],
37 | "source": [
38 | "print(data[0]['seqName'])\n",
39 | "print(data[0]['alignmentScore'])"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": null,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "data[0]['aaSubstitutions']"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": null,
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "data[0]['aaDeletions']"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": null,
63 | "metadata": {},
64 | "outputs": [],
65 | "source": [
66 | "df = pd.read_csv(\"results/gisaid.subset.tsv\", sep=\"\\t\")"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": null,
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "df"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": null,
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "df.columns"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": null,
90 | "metadata": {},
91 | "outputs": [],
92 | "source": [
93 | "df[\"aaSubstitutions\"].tolist()"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": null,
99 | "metadata": {},
100 | "outputs": [],
101 | "source": [
102 | "df[\"aaDeletions\"]"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": null,
108 | "metadata": {},
109 | "outputs": [],
110 | "source": [
111 | "df[\"aaDeletions\"][0]"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": null,
117 | "metadata": {},
118 | "outputs": [],
119 | "source": [
120 | "len(df)"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": null,
126 | "metadata": {},
127 | "outputs": [],
128 | "source": [
129 | "len(df.columns)"
130 | ]
131 | },
132 | {
133 | "cell_type": "code",
134 | "execution_count": null,
135 | "metadata": {},
136 | "outputs": [],
137 | "source": []
138 | }
139 | ],
140 | "metadata": {
141 | "kernelspec": {
142 | "display_name": "Python 3 (ipykernel)",
143 | "language": "python",
144 | "name": "python3"
145 | },
146 | "language_info": {
147 | "codemirror_mode": {
148 | "name": "ipython",
149 | "version": 3
150 | },
151 | "file_extension": ".py",
152 | "mimetype": "text/x-python",
153 | "name": "python",
154 | "nbconvert_exporter": "python",
155 | "pygments_lexer": "ipython3",
156 | "version": "3.9.1"
157 | }
158 | },
159 | "nbformat": 4,
160 | "nbformat_minor": 4
161 | }
162 |
--------------------------------------------------------------------------------
/notebooks/explore-owid-vaccinations.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "import datetime\n",
11 | "import pandas as pd\n",
12 | "import torch\n",
13 | "import matplotlib.pyplot as plt"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": null,
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "path = os.path.expanduser(\"~/github/owid/covid-19-data/public/data/vaccinations/vaccinations.csv\")\n",
23 | "df = pd.read_csv(path, header=0)\n",
24 | "df"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": null,
30 | "metadata": {},
31 | "outputs": [],
32 | "source": [
33 | "portion = df[\"total_vaccinations_per_hundred\"].to_numpy() / 100\n",
34 | "portion"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": null,
40 | "metadata": {},
41 | "outputs": [],
42 | "source": [
43 | "locations = df[\"location\"].to_list()\n",
44 | "location_id = sorted(set(locations))\n",
45 | "location_id = {name: i for i, name in enumerate(location_id)}\n",
46 | "locations = torch.tensor([location_id[n] for n in locations], dtype=torch.long)"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": null,
52 | "metadata": {},
53 | "outputs": [],
54 | "source": [
55 | "def parse_date(s):\n",
56 | " return datetime.datetime.strptime(s, \"%Y-%m-%d\")\n",
57 | "\n",
58 | "start_date = parse_date(\"2020-12-01\")\n",
59 | "dates = torch.tensor([(parse_date(d) - start_date).days\n",
60 | " for d in df[\"date\"]])\n",
61 | "assert dates.min() >= 0"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": null,
67 | "metadata": {},
68 | "outputs": [],
69 | "source": [
70 | "T = int(1 + dates.max())\n",
71 | "R = len(location_id)\n",
72 | "dense_portion = torch"
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": null,
78 | "metadata": {},
79 | "outputs": [],
80 | "source": [
81 | "plt.scatter(dates, portion, s=10, alpha=0.5);"
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": null,
87 | "metadata": {},
88 | "outputs": [],
89 | "source": []
90 | }
91 | ],
92 | "metadata": {
93 | "kernelspec": {
94 | "display_name": "Python 3",
95 | "language": "python",
96 | "name": "python3"
97 | },
98 | "language_info": {
99 | "codemirror_mode": {
100 | "name": "ipython",
101 | "version": 3
102 | },
103 | "file_extension": ".py",
104 | "mimetype": "text/x-python",
105 | "name": "python",
106 | "nbconvert_exporter": "python",
107 | "pygments_lexer": "ipython3",
108 | "version": "3.6.12"
109 | }
110 | },
111 | "nbformat": 4,
112 | "nbformat_minor": 4
113 | }
114 |
--------------------------------------------------------------------------------
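Building on the `dense_portion` matrix constructed in the notebook's final data cell, a short usage sketch plotting one location's cumulative vaccination portion over time (reusing `location_id` and `T` from the notebook; the `"United States"` key is an assumed OWID location name):

```python
# Plot one location's curve; NaN entries simply leave gaps in the line.
row = dense_portion[location_id["United States"]]
plt.plot(range(T), row.numpy())
plt.xlabel("days since 2020-12-01")
plt.ylabel("total vaccinations per capita")
```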
/notebooks/explore_pairs_of_lineages.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Compare pairs of lineages w.r.t. mutational profiles and determinants of transmissibility"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import pickle\n",
17 | "import numpy as np\n",
18 | "import matplotlib\n",
19 | "import matplotlib.pyplot as plt\n",
20 | "import torch\n",
21 | "from pyrocov import pangolin\n",
22 | "import pandas as pd\n",
23 | "\n",
24 | "matplotlib.rcParams[\"figure.dpi\"] = 200\n",
25 | "matplotlib.rcParams[\"axes.edgecolor\"] = \"gray\"\n",
26 | "matplotlib.rcParams[\"savefig.bbox\"] = \"tight\"\n",
27 | "matplotlib.rcParams['font.family'] = 'sans-serif'\n",
28 | "matplotlib.rcParams['font.sans-serif'] = ['Arial', 'Avenir', 'DejaVu Sans']"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 2,
34 | "metadata": {},
35 | "outputs": [
36 | {
37 | "name": "stdout",
38 | "output_type": "stream",
39 | "text": [
40 | "dict_keys(['location_id', 'mutations', 'weekly_clades', 'features', 'lineage_id', 'lineage_id_inv', 'time_shift'])\n"
41 | ]
42 | }
43 | ],
44 | "source": [
45 | "dataset = torch.load(\"results/mutrans.data.single.None.pt\", map_location=\"cpu\")\n",
46 | "print(dataset.keys())\n",
47 | "locals().update(dataset)"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": 3,
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "features = dataset['features']\n",
57 | "coefficients = pd.read_csv(\"paper/mutations.tsv\", sep=\"\\t\", index_col=1)\n",
58 | "coefficients = coefficients.loc[dataset['mutations']].copy()\n",
59 | "feature_names = coefficients.index.values.tolist()\n",
60 | "\n",
61 | "lineage_id = {name: i for i, name in enumerate(lineage_id_inv)}\n",
62 | "lineage_id_inv = dataset['lineage_id_inv']\n",
63 | "\n",
64 | "deltaR = coefficients['Δ log R'].values\n",
65 | "zscore = coefficients['mean/stddev'].values"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": 4,
71 | "metadata": {},
72 | "outputs": [],
73 | "source": [
74 | "##########################################\n",
75 | "### select pair of lineages to compare ###\n",
76 | "##########################################\n",
77 | "A, B = 'B.1.617.2', 'B.1'\n",
78 | "#A, B = 'B.1.1.7', 'B.1.1'\n",
79 | "#A, B = 'B.1.427', 'B.1'\n",
80 | "#A, B = 'B.1.351', 'B.1'\n",
81 | "#A, B = 'P.1', 'B.1.1'\n",
82 | "#A, B = 'AY.2', 'B.1.617.2'\n",
83 | "\n",
84 | "A_id, B_id = lineage_id[A], lineage_id[B]\n",
85 | "A_feat, B_feat = features[A_id].numpy(), features[B_id].numpy()\n",
86 | "\n",
87 | "delta_cov = A_feat - B_feat\n",
88 | "delta_cov_abs = np.fabs(A_feat - B_feat)"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": 5,
94 | "metadata": {},
95 | "outputs": [
96 | {
97 | "name": "stdout",
98 | "output_type": "stream",
99 | "text": [
100 | "deltaR_cutoff 0.042476400000000004\n",
101 | "ORF1a:P2287S \t deltaR: 0.044 zscore: 253.71 \t\t delta_feature: 0.80\n",
102 | "ORF1a:T3255I \t deltaR: 0.045 zscore: 240.18 \t\t delta_feature: 0.80\n",
103 | "S:L452R \t deltaR: 0.048 zscore: 244.88 \t\t delta_feature: 0.98\n",
104 | "S:P681R \t deltaR: 0.051 zscore: 300.61 \t\t delta_feature: 0.97\n",
105 | "\n",
106 | " B.1.617.2 over B.1\n",
107 | "ORF1a:P2287S, ORF1a:T3255I, S:L452R, S:P681R, "
108 | ]
109 | }
110 | ],
111 | "source": [
112 | "# look at top 100 mutations w.r.t. effect size\n",
113 | "deltaR_cutoff = np.fabs(deltaR)[np.argsort(np.fabs(deltaR))[-100]]\n",
114 | "print(\"deltaR_cutoff\", deltaR_cutoff)\n",
115 | "\n",
116 | "selected_features = []\n",
117 | "\n",
118 | "for i, name in enumerate(feature_names):\n",
119 | " if len(name) <= 6:\n",
120 | " name = name + \" \"\n",
121 | " dR = deltaR[i]\n",
122 | " dC = delta_cov[i]\n",
123 | " z = zscore[i]\n",
124 | " if dR > deltaR_cutoff and np.fabs(dC) > 0.5:\n",
125 | " selected_features.append(name)\n",
126 | " print(\"{} \\t deltaR: {:.3f} zscore: {:.2f} \\t\\t delta_feature: {:.2f}\".format(name, dR, z, dC))\n",
127 | "\n",
128 | "print(\"\\n\", A, \"over\" ,B)\n",
129 | "for s in selected_features:\n",
130 | " print(s + \", \", end='')"
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "execution_count": 6,
136 | "metadata": {},
137 | "outputs": [
138 | {
139 | "name": "stdout",
140 | "output_type": "stream",
141 | "text": [
142 | "M:I82T \t deltaR: 0.038 zscore: 230.30 \t\t delta_feature: 1.00\n",
143 | "N:D63G \t deltaR: 0.028 zscore: 201.31 \t\t delta_feature: 0.99\n",
144 | "ORF1a:P2287S \t deltaR: 0.044 zscore: 253.71 \t\t delta_feature: 0.80\n",
145 | "ORF1a:T3255I \t deltaR: 0.045 zscore: 240.18 \t\t delta_feature: 0.80\n",
146 | "ORF1b:G662S \t deltaR: 0.027 zscore: 228.74 \t\t delta_feature: 0.98\n",
147 | "ORF1b:P1000L \t deltaR: 0.035 zscore: 211.10 \t\t delta_feature: 0.98\n",
148 | "S:D950N \t deltaR: 0.036 zscore: 231.28 \t\t delta_feature: 0.98\n",
149 | "S:E156- \t deltaR: 0.028 zscore: 203.85 \t\t delta_feature: 0.95\n",
150 | "S:L452R \t deltaR: 0.048 zscore: 244.88 \t\t delta_feature: 0.98\n",
151 | "S:P681R \t deltaR: 0.051 zscore: 300.61 \t\t delta_feature: 0.97\n",
152 | "\n",
153 | " B.1.617.2 over B.1\n",
154 | "M:I82T , N:D63G , ORF1a:P2287S, ORF1a:T3255I, ORF1b:G662S, ORF1b:P1000L, S:D950N, S:E156-, S:L452R, S:P681R, "
155 | ]
156 | }
157 | ],
158 | "source": [
159 | "selected_features = []\n",
160 | "\n",
161 | "# look at large z-score mutations (i.e. increase growth rate)\n",
162 | "for i, name in enumerate(feature_names):\n",
163 | " if len(name) <= 6:\n",
164 | " name = name + \" \"\n",
165 | " dR = deltaR[i]\n",
166 | " dC = delta_cov[i]\n",
167 | " z = zscore[i]\n",
168 | " if z > 200.0 and np.fabs(dC) > 0.5:\n",
169 | " selected_features.append(name)\n",
170 | " print(\"{} \\t deltaR: {:.3f} zscore: {:.2f} \\t\\t delta_feature: {:.2f}\".format(name, dR, z, dC))\n",
171 | "\n",
172 | "print(\"\\n\", A, \"over\" ,B)\n",
173 | "for s in selected_features:\n",
174 | " print(s + \", \", end='')"
175 | ]
176 | },
177 | {
178 | "cell_type": "code",
179 | "execution_count": 7,
180 | "metadata": {},
181 | "outputs": [],
182 | "source": [
183 | "def findfeat(s):\n",
184 | " for i, n in enumerate(feature_names):\n",
185 | " if n==s:\n",
186 | " return i\n",
187 | " return -1"
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "execution_count": 8,
193 | "metadata": {},
194 | "outputs": [
195 | {
196 | "name": "stdout",
197 | "output_type": "stream",
198 | "text": [
199 | "0.0054945056\n",
200 | "0.96954316\n"
201 | ]
202 | }
203 | ],
204 | "source": [
205 | "print(features[lineage_id['B.1']].numpy()[findfeat('S:H69-')])\n",
206 | "print(features[lineage_id['B.1.1.7']].numpy()[findfeat('S:H69-')])"
207 | ]
208 | },
209 | {
210 | "cell_type": "code",
211 | "execution_count": 9,
212 | "metadata": {},
213 | "outputs": [
214 | {
215 | "data": {
216 | "text/plain": [
217 | "array([[1.73514e-02, 2.10334e+02]])"
218 | ]
219 | },
220 | "execution_count": 9,
221 | "metadata": {},
222 | "output_type": "execute_result"
223 | }
224 | ],
225 | "source": [
226 | "coefficients[coefficients.index == 'S:H69-'][['Δ log R', 'mean/stddev']].values"
227 | ]
228 | }
229 | ],
230 | "metadata": {
231 | "kernelspec": {
232 | "display_name": "Python 3",
233 | "language": "python",
234 | "name": "python3"
235 | },
236 | "language_info": {
237 | "codemirror_mode": {
238 | "name": "ipython",
239 | "version": 3
240 | },
241 | "file_extension": ".py",
242 | "mimetype": "text/x-python",
243 | "name": "python",
244 | "nbconvert_exporter": "python",
245 | "pygments_lexer": "ipython3",
246 | "version": "3.8.3"
247 | }
248 | },
249 | "nbformat": 4,
250 | "nbformat_minor": 4
251 | }
252 |
--------------------------------------------------------------------------------
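The two selection cells above repeat the same loop with different thresholds. A small refactoring sketch that captures the shared pattern, reusing `feature_names`, `delta_cov`, `deltaR`, and `zscore` from the notebook:

```python
def select_features(score, threshold, min_delta_cov=0.5):
    """Return (name, score, delta_cov) triples for mutations whose score
    exceeds the threshold and whose coverage differs between the lineages."""
    return [
        (name, s, dc)
        for name, s, dc in zip(feature_names, score, delta_cov)
        if s > threshold and abs(dc) > min_delta_cov
    ]

# e.g. select_features(deltaR, deltaR_cutoff) or select_features(zscore, 200.0)
```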
/notebooks/explore_pango_lineages.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from pyrocov import pangolin"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "for k, v in sorted(pangolin.PANGOLIN_ALIASES.items()):\n",
19 | " print(f\"{k}\\t{v}\")"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": null,
25 | "metadata": {},
26 | "outputs": [],
27 | "source": [
28 | "pangolin.update_aliases()"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": []
37 | }
38 | ],
39 | "metadata": {
40 | "kernelspec": {
41 | "display_name": "Python 3",
42 | "language": "python",
43 | "name": "python3"
44 | },
45 | "language_info": {
46 | "codemirror_mode": {
47 | "name": "ipython",
48 | "version": 3
49 | },
50 | "file_extension": ".py",
51 | "mimetype": "text/x-python",
52 | "name": "python",
53 | "nbconvert_exporter": "python",
54 | "pygments_lexer": "ipython3",
55 | "version": "3.6.12"
56 | }
57 | },
58 | "nbformat": 4,
59 | "nbformat_minor": 4
60 | }
61 |
--------------------------------------------------------------------------------
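For context, `pangolin.PANGOLIN_ALIASES` maps Pango alias prefixes to full lineage names. A tiny usage sketch of the `decompress`/`compress` helpers that other notebooks in this repo (e.g. mutrans_loo.ipynb) build on; the concrete expansion shown in the comment is an assumption about the alias table:

```python
from pyrocov import pangolin

full = pangolin.decompress("AY.4")   # expand an alias, e.g. to "B.1.617.2.4"
short = pangolin.compress(full)      # back to the shortest aliased form
print(full, short)
```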
/notebooks/grid_search.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Analyzing results of grid search\n",
8 | "\n",
9 | "This notebook assumes you've downloaded data and run a grid search experiment\n",
10 | "```sh\n",
11 | "make update # many hours\n",
12 | "python mutrans.py --grid-search # many hours\n",
13 | "```"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": null,
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "import math\n",
23 | "import re\n",
24 | "import numpy as np\n",
25 | "import pandas as pd\n",
26 | "import matplotlib\n",
27 | "import matplotlib.pyplot as plt\n",
28 | "from pyrocov.util import pearson_correlation\n",
29 | "from pyrocov.plotting import force_apart\n",
30 | "\n",
31 | "matplotlib.rcParams[\"figure.dpi\"] = 200"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": null,
37 | "metadata": {
38 | "scrolled": false
39 | },
40 | "outputs": [],
41 | "source": [
42 | "df = pd.read_csv(\"results/grid_search.tsv\", sep=\"\\t\")\n",
43 | "df = df.fillna(\"\")\n",
44 | "df"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": null,
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "df.columns"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": null,
59 | "metadata": {
60 | "scrolled": false
61 | },
62 | "outputs": [],
63 | "source": [
64 | "model_type = df[\"model_type\"].to_list()\n",
65 | "cond_data = df[\"cond_data\"].to_list()\n",
66 | "mutation_corr = df[\"mutation_corr\"].to_numpy()\n",
67 | "mutation_error = df[\"mutation_rmse\"].to_numpy() / df[\"mutation_stddev\"].to_numpy()\n",
68 | "mae_pred = df[\"England B.1.1.7 MAE\"].to_numpy()\n",
69 | "\n",
70 | "loss = df[\"loss\"].to_numpy()\n",
71 | "min_loss, max_loss = loss.min(), loss.max()\n",
72 | "assert (loss > 0).all(), \"you'll need to switch to symlog or sth\"\n",
73 | "loss = np.log(loss)\n",
74 | "loss -= loss.min()\n",
75 | "loss /= loss.max()\n",
76 | "R_alpha = df[\"R(B.1.1.7)/R(A)\"].to_numpy()\n",
77 | "R_delta = df[\"R(B.1.617.2)/R(A)\"].to_numpy()\n",
78 | "\n",
79 | "def plot_concordance(filenames=[], colorby=\"R\"):\n",
80 | " legend = {}\n",
81 | " def abbreviate_param(match):\n",
82 | " k = match.group()[:-1]\n",
83 | " v = k[0].upper()\n",
84 | " legend[v] = k\n",
85 | " return v\n",
86 | " def abbreviate_sample(match):\n",
87 | " k = match.group()[:-1]\n",
88 | " v = k[0]\n",
89 | " legend[v] = k\n",
90 | " return v + \"꞊\"\n",
91 | " fig, axes = plt.subplots(2, figsize=(8, 12))\n",
92 | " for ax, X, Y, xlabel, ylabel in zip(\n",
93 | " axes, [mutation_error, R_alpha], [mae_pred, R_delta],\n",
94 | " [\n",
95 | " # \"Pearson correlation of mutaitons\",\n",
96 | " \"Cross-validation error of mutation coefficients (lower is better)\",\n",
97 | " \"R(α) / R(A)\"],\n",
98 | " [\"England α portion MAE (lower is better)\", \"R(δ) / R(A)\"]\n",
99 | " ):\n",
100 | " ax.scatter(X, Y, 30, loss, lw=0, alpha=0.8, cmap=\"coolwarm\")\n",
101 | " ax.set_xlabel(xlabel)\n",
102 | " ax.set_ylabel(ylabel)\n",
103 | " \n",
104 | " X_, Y_ = force_apart(X, Y, stepsize=2)\n",
105 | " assert X_.dim() == 1\n",
106 | " X_X = []\n",
107 | " Y_Y = []\n",
108 | " for x_, x, y_, y in zip(X_, X, Y_, Y):\n",
109 | " X_X.extend([float(x_), float(x), None])\n",
110 | " Y_Y.extend([float(y_), float(y), None])\n",
111 | " ax.plot(X_X, Y_Y, \"k-\", lw=0.5, alpha=0.5, zorder=-10)\n",
112 | " for x, y, mt, cd, l in zip(X_, Y_, model_type, cond_data, loss):\n",
113 | " name = f\"{mt}-{cd}\"\n",
114 | " name = re.sub(\"[a-z_]+-\", abbreviate_param, name)\n",
115 | " name = re.sub(\"[a-z_]+=\", abbreviate_sample, name)\n",
116 | " name = name.replace(\"-\", \"\")\n",
117 | " ax.text(x, y, name, fontsize=7, va=\"center\", alpha=1 - 0.666 * l)\n",
118 | " \n",
119 | " axes[0].set_xscale(\"log\")\n",
120 | " axes[0].set_yscale(\"log\")\n",
121 | " axes[0].plot([], [], \"bo\", markeredgewidth=0, markersize=5, alpha=0.5,\n",
122 | " label=f\"loss={min_loss:0.2g} (better)\")\n",
123 | " axes[0].plot([], [], \"ro\", markeredgewidth=0, markersize=5, alpha=0.5,\n",
124 | " label=f\"loss={max_loss:0.2g} (worse)\")\n",
125 | " for k, v in sorted(legend.items()):\n",
126 | " axes[0].plot([], [], \"wo\", label=f\"{k} = {v}\")\n",
127 | " axes[0].legend(loc=\"upper right\", fontsize=\"small\")\n",
128 | " min_max = [max(X.min(), Y.min()), min(X.max(), Y.max())]\n",
129 | " axes[1].plot(min_max, min_max, \"k--\", alpha=0.2, zorder=-10)\n",
130 | " plt.subplots_adjust(hspace=0.15)\n",
131 | " for filename in filenames:\n",
132 | " plt.savefig(filename)\n",
133 | " \n",
134 | "plot_concordance([\"paper/grid_search.png\"])"
135 | ]
136 | },
137 | {
138 | "cell_type": "code",
139 | "execution_count": null,
140 | "metadata": {
141 | "scrolled": false
142 | },
143 | "outputs": [],
144 | "source": [
145 | "import torch\n",
146 | "grid = torch.load(\"results/mutrans.grid.pt\")\n",
147 | "\n",
148 | "def plot_mutation_agreements(grid):\n",
149 | " fig, axes = plt.subplots(len(grid), 3, figsize=(8, 1 + 3 * len(grid)))\n",
150 | " for axe, (name, holdouts) in zip(axes, sorted(grid.items())):\n",
151 | " (name0, fit0), (name1, fit1), (name2, fit2) = holdouts.items()\n",
152 | " pairs = [\n",
153 | " [(name0, fit0), (name1, fit1)],\n",
154 | " [(name0, fit0), (name2, fit2)],\n",
155 | " [(name1, fit1), (name2, fit2)],\n",
156 | " ]\n",
157 | " means = [v[\"coef\"] * 0.01 for v in holdouts.values()]\n",
158 | " x0 = min(mean.min().item() for mean in means)\n",
159 | " x1 = max(mean.max().item() for mean in means)\n",
160 | " lb = 1.05 * x0 - 0.05 * x1\n",
161 | " ub = 1.05 * x1 - 0.05 * x0\n",
162 | " axe[1].set_title(str(name))\n",
163 | " axe[0].set_ylabel(str(name).replace(\"-\", \"\\n\").replace(\",\", \"\\n\"), fontsize=8)\n",
164 | " for ax, ((name1, fit1), (name2, fit2)) in zip(axe, pairs):\n",
165 | " mutations = sorted(set(fit1[\"mutations\"]) & set(fit2[\"mutations\"]))\n",
166 | " means = []\n",
167 | " for fit in (fit1, fit2):\n",
168 | " m_to_i = {m: i for i, m in enumerate(fit[\"mutations\"])}\n",
169 | " idx = torch.tensor([m_to_i[m] for m in mutations])\n",
170 | " means.append(fit[\"coef\"])\n",
171 | " ax.plot([lb, ub], [lb, ub], 'k--', alpha=0.3, zorder=-100)\n",
172 | " ax.scatter(means[1].numpy(), means[0].numpy(), 30, alpha=0.3, lw=0, color=\"darkred\")\n",
173 | " ax.axis(\"equal\")\n",
174 | " ax.set_title(\"ρ = {:0.2g}\".format(pearson_correlation(means[0], means[1])))\n",
175 | "plot_mutation_agreements(grid)"
176 | ]
177 | },
178 | {
179 | "cell_type": "markdown",
180 | "metadata": {},
181 | "source": [
182 | "## Debugging plotting code"
183 | ]
184 | },
185 | {
186 | "cell_type": "code",
187 | "execution_count": null,
188 | "metadata": {
189 | "scrolled": false
190 | },
191 | "outputs": [],
192 | "source": [
193 | "from pyrocov.plotting import force_apart\n",
194 | "torch.manual_seed(1234567890)\n",
195 | "X, Y = torch.randn(2, 200)\n",
196 | "X_, Y_ = force_apart(X, Y)\n",
197 | "plt.plot(X, Y, \"ko\")\n",
198 | "for i in range(8):\n",
199 | " plt.plot(X_ + i / 20, Y_, \"r.\")\n",
200 | "X_X = []\n",
201 | "Y_Y = []\n",
202 | "for x_, x, y_, y in zip(X_, X, Y_, Y):\n",
203 | " X_X.extend([float(x_), float(x), None])\n",
204 | " Y_Y.extend([float(y_), float(y), None])\n",
205 | "plt.plot(X_X, Y_Y, \"k-\", lw=0.5, alpha=0.5, zorder=-10);"
206 | ]
207 | },
208 | {
209 | "cell_type": "code",
210 | "execution_count": null,
211 | "metadata": {},
212 | "outputs": [],
213 | "source": []
214 | }
215 | ],
216 | "metadata": {
217 | "kernelspec": {
218 | "display_name": "Python 3",
219 | "language": "python",
220 | "name": "python3"
221 | },
222 | "language_info": {
223 | "codemirror_mode": {
224 | "name": "ipython",
225 | "version": 3
226 | },
227 | "file_extension": ".py",
228 | "mimetype": "text/x-python",
229 | "name": "python",
230 | "nbconvert_exporter": "python",
231 | "pygments_lexer": "ipython3",
232 | "version": "3.6.12"
233 | }
234 | },
235 | "nbformat": 4,
236 | "nbformat_minor": 4
237 | }
238 |
--------------------------------------------------------------------------------
/notebooks/mutrans_loo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Predictive accuracy of mutrans model on new lineages\n",
8 | "\n",
9 | "This notebook assumes you have run\n",
10 | "```sh\n",
11 | "make update\n",
12 | "make preprocess\n",
13 | "python mutrans.py --vary-leaves\n",
14 | "```"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "import pickle\n",
24 | "import numpy as np\n",
25 | "import matplotlib\n",
26 | "import matplotlib.pyplot as plt\n",
27 | "import torch\n",
28 | "from pyrocov import pangolin\n",
29 | "from pyrocov.util import pearson_correlation, quotient_central_moments\n",
30 | "\n",
31 | "matplotlib.rcParams[\"figure.dpi\"] = 200\n",
32 | "matplotlib.rcParams[\"axes.edgecolor\"] = \"gray\"\n",
33 | "matplotlib.rcParams[\"figure.facecolor\"] = \"white\"\n",
34 | "matplotlib.rcParams[\"savefig.bbox\"] = \"tight\"\n",
35 | "matplotlib.rcParams['font.family'] = 'sans-serif'\n",
36 | "matplotlib.rcParams['font.sans-serif'] = ['Arial', 'Avenir', 'DejaVu Sans']"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": null,
42 | "metadata": {},
43 | "outputs": [],
44 | "source": [
45 | "dataset = torch.load(\"results/mutrans.data.single.3000.1.50.None.pt\", map_location=\"cpu\")\n",
46 | "print(dataset.keys())\n",
47 | "locals().update(dataset)"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": null,
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "lineage_id = dataset[\"lineage_id\"]\n",
57 | "clade_id_to_lineage_id = dataset[\"clade_id_to_lineage_id\"]"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": null,
63 | "metadata": {},
64 | "outputs": [],
65 | "source": [
66 | "loo = torch.load(\"results/mutrans.vary_leaves.pt\", map_location=\"cpu\")\n",
67 | "print(len(loo))\n",
68 | "print(list(loo)[-1])\n",
69 | "print(list(loo.values())[-1].keys())"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": null,
75 | "metadata": {},
76 | "outputs": [],
77 | "source": [
78 | "print(len(mutations))\n",
79 | "print(list(loo.values())[0][\"median\"][\"rate_loc\"].shape)"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": null,
85 | "metadata": {},
86 | "outputs": [],
87 | "source": [
88 | "best_rate_loc = None\n",
89 | "loo_rate_loc = {}\n",
90 | "for k, v in loo.items():\n",
91 | "    rate = quotient_central_moments(v[\"median\"][\"rate_loc\"], clade_id_to_lineage_id)[1]  # lineage-level posterior mean rate\n",
92 | "    holdout = k[-1]\n",
93 | "    if holdout:\n",
94 | "        key = holdout[-1][-1][-1][-1].replace(\"$\", \"\").replace(\"^\", \"\")  # strip regex anchors from the held-out lineage name\n",
95 | "        loo_rate_loc[key] = rate\n",
96 | "    else:\n",
97 | "        best_rate_loc = rate  # empty holdout: the full-data fit\n",
98 | "print(\" \".join(loo_rate_loc))"
99 | ]
100 | },
101 | {
102 | "cell_type": "code",
103 | "execution_count": null,
104 | "metadata": {},
105 | "outputs": [],
106 | "source": [
107 | "def plot_prediction(filenames=(), debug=False, use_who=True):\n",
108 | " X1, Y1, X2, Y2, labels, debug_labels = [], [], [], [], [], []\n",
109 | "    who = {vs[0]: k for k, vs in pangolin.WHO_ALIASES.items()}  # PANGO lineage -> WHO label\n",
110 | " ancestors = set(lineage_id)\n",
111 | " for child, rate_loc in loo_rate_loc.items():\n",
112 | " parent = pangolin.compress(\n",
113 | " pangolin.get_most_recent_ancestor(\n",
114 | " pangolin.decompress(child), ancestors\n",
115 | " )\n",
116 | " )\n",
117 | " c = lineage_id[child]\n",
118 | " p = lineage_id[parent]\n",
119 | "        truth = best_rate_loc[c].item()  # full-data estimate for the child\n",
120 | "        baseline = rate_loc[p].item()  # LOO estimate for the nearest modeled ancestor\n",
121 | "        guess = rate_loc[c].item()  # LOO estimate for the held-out child\n",
122 | " Y1.append(truth)\n",
123 | " X1.append(guess)\n",
124 | " Y2.append(truth - baseline)\n",
125 | " X2.append(guess - baseline)\n",
126 | " labels.append(who.get(child))\n",
127 | " debug_labels.append(child)\n",
128 | " mae = np.abs(np.array(Y2)).mean()\n",
129 | " print(f\"MAE(baseline - full estimate) = {mae:0.4g}\")\n",
130 | " fig, axes = plt.subplots(1, 2, figsize=(7, 3.5))\n",
131 | " for ax, X, Y in zip(axes, [X1, X2], [Y1, Y2]):\n",
132 | " X = np.array(X)\n",
133 | " Y = np.array(Y)\n",
134 | " ax.scatter(X, Y, 40, lw=0, alpha=1, color=\"white\", zorder=-5)\n",
135 | " ax.scatter(X, Y, 20, lw=0, alpha=0.3, color=\"darkred\")\n",
136 | " lb = min(min(X), min(Y))\n",
137 | " ub = max(max(X), max(Y))\n",
138 | " d = ub - lb\n",
139 | " lb -= 0.03 * d\n",
140 | " ub += 0.05 * d\n",
141 | " ax.plot([lb, ub], [lb, ub], \"k--\", alpha=0.2, zorder=-10)\n",
142 | " ax.set_xlim(lb, ub)\n",
143 | " ax.set_ylim(lb, ub)\n",
144 | " rho = pearson_correlation(X, Y)\n",
145 | " mae = np.abs(X - Y).mean()\n",
146 | " ax.text(0.3 * lb + 0.7 * ub, 0.8 * lb + 0.2 * ub,\n",
147 | " #f\" ρ = {rho:0.3f}\\nMAE = {mae:0.3g}\",\n",
148 | " f\" ρ = {rho:0.3f}\",\n",
149 | " backgroundcolor=\"white\", ha=\"center\", va=\"center\")\n",
150 | " for x, y, label, debug_label in zip(X, Y, labels, debug_labels):\n",
151 | " pad = 0.012\n",
152 | " if label is not None:\n",
153 | " ax.plot([x], [y], \"ko\", mfc=\"#c77\", c=\"black\", ms=4, mew=0.5)\n",
154 | " ax.text(x, y + pad, label if use_who else debug_label,\n",
155 | " va=\"bottom\", ha=\"center\", fontsize=6)\n",
156 | " elif abs(x - y) > 0.2:\n",
157 | " ax.plot([x], [y], \"ko\", mfc=\"#c77\", c=\"black\", ms=4, mew=0.5)\n",
158 | " ax.text(x, y + pad, debug_label, va=\"bottom\", ha=\"center\", fontsize=6)\n",
159 | " axes[0].set_ylabel(\"full estimate\")\n",
160 | " axes[0].set_xlabel(\"LOO estimate\")\n",
161 | " axes[1].set_ylabel(\"full estimate − baseline\")\n",
162 | " axes[1].set_xlabel(\"LOO estimate − baseline\")\n",
163 | " plt.tight_layout()\n",
164 | " for f in filenames:\n",
165 | " plt.savefig(f)\n",
166 | "plot_prediction(debug=True)\n",
167 | "plot_prediction(use_who=False, filenames=[\"paper/lineage_prediction.png\"])"
168 | ]
169 | },
170 | {
171 | "cell_type": "code",
172 | "execution_count": null,
173 | "metadata": {},
174 | "outputs": [],
175 | "source": []
176 | }
177 | ],
178 | "metadata": {
179 | "kernelspec": {
180 | "display_name": "Python 3",
181 | "language": "python",
182 | "name": "python3"
183 | },
184 | "language_info": {
185 | "codemirror_mode": {
186 | "name": "ipython",
187 | "version": 3
188 | },
189 | "file_extension": ".py",
190 | "mimetype": "text/x-python",
191 | "name": "python",
192 | "nbconvert_exporter": "python",
193 | "pygments_lexer": "ipython3",
194 | "version": "3.6.12"
195 | }
196 | },
197 | "nbformat": 4,
198 | "nbformat_minor": 4
199 | }
200 |
--------------------------------------------------------------------------------
/notebooks/paper:
--------------------------------------------------------------------------------
1 | ../paper
--------------------------------------------------------------------------------
/notebooks/preprocess-pairwise-allele-fractions.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "96af8117",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import torch"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 2,
16 | "id": "7e9fa3b9",
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "max_num_clades = 3000\n",
21 | "min_num_mutations = 1"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 3,
27 | "id": "eb793068",
28 | "metadata": {},
29 | "outputs": [
30 | {
31 | "name": "stdout",
32 | "output_type": "stream",
33 | "text": [
34 | "clade_id \tdict of size 3000\n",
35 | "clade_id_inv \tlist of size 3000\n",
36 | "clade_id_to_lineage_id \tTensor of shape (3000,)\n",
37 | "clade_to_lineage \tdict of size 3000\n",
38 | "features \tTensor of shape (3000, 2904)\n",
39 | "lineage_id \tdict of size 1544\n",
40 | "lineage_id_inv \tlist of size 1544\n",
41 | "lineage_id_to_clade_id \tTensor of shape (1544,)\n",
42 | "lineage_to_clade \tdict of size 1544\n",
43 | "location_id \tOrderedDict of size 1560\n",
44 | "location_id_inv \tlist of size 1560\n",
45 | "mutations \tlist of size 2904\n",
46 | "pc_index \tTensor of shape (185001,)\n",
47 | "sparse_counts \tdict of size 3\n",
48 | "state_to_country \tTensor of shape (1355,)\n",
49 | "time \tTensor of shape (56,)\n",
50 | "weekly_clades \tTensor of shape (56, 1560, 3000)\n",
51 | "CPU times: user 15 ms, sys: 891 ms, total: 906 ms\n",
52 | "Wall time: 1.46 s\n"
53 | ]
54 | }
55 | ],
56 | "source": [
57 | "%%time\n",
58 | "def load_data():\n",
59 | " filename = f\"results/mutrans.data.single.{max_num_clades}.{min_num_mutations}.50.None.pt\"\n",
60 | " dataset = torch.load(filename, map_location=\"cpu\")\n",
61 | " return dataset\n",
62 | "dataset = load_data()\n",
63 | "locals().update(dataset)\n",
64 | "for k, v in sorted(dataset.items()):\n",
65 | " if isinstance(v, torch.Tensor):\n",
66 | " print(f\"{k} \\t{type(v).__name__} of shape {tuple(v.shape)}\")\n",
67 | " else:\n",
68 | " print(f\"{k} \\t{type(v).__name__} of size {len(v)}\")"
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "execution_count": null,
74 | "id": "ffd0f8ed",
75 | "metadata": {},
76 | "outputs": [],
77 | "source": [
78 | "# TODO(martin) Use features, mutations, weekly_clades"
79 | ]
80 | }
81 | ],
82 | "metadata": {
83 | "kernelspec": {
84 | "display_name": "Python 3 (ipykernel)",
85 | "language": "python",
86 | "name": "python3"
87 | },
88 | "language_info": {
89 | "codemirror_mode": {
90 | "name": "ipython",
91 | "version": 3
92 | },
93 | "file_extension": ".py",
94 | "mimetype": "text/x-python",
95 | "name": "python",
96 | "nbconvert_exporter": "python",
97 | "pygments_lexer": "ipython3",
98 | "version": "3.9.1"
99 | }
100 | },
101 | "nbformat": 4,
102 | "nbformat_minor": 5
103 | }
104 |
--------------------------------------------------------------------------------
/notebooks/results:
--------------------------------------------------------------------------------
1 | ../results
--------------------------------------------------------------------------------
/paper/.gitignore:
--------------------------------------------------------------------------------
1 | *.aux
2 | *.bbl
3 | *.blg
4 | *.pdf
5 | *.out
6 |
--------------------------------------------------------------------------------
/paper/BA.1_logistic_regression_all_US.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/BA.1_logistic_regression_all_US.jpg
--------------------------------------------------------------------------------
/paper/Makefile:
--------------------------------------------------------------------------------
1 | all: supplement.pdf
2 |
3 | push:
4 | cp supplement.pdf ../data/
5 |
6 | supplement.pdf: supplement.tex FORCE
7 | pdflatex supplement
8 | bibtex supplement
9 | 	pdflatex supplement
10 | 	pdflatex supplement
11 | 
12 | FORCE:
13 | 
--------------------------------------------------------------------------------
/paper/N_S_peak_ratio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/N_S_peak_ratio.png
--------------------------------------------------------------------------------
/paper/N_S_ratio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/N_S_ratio.png
--------------------------------------------------------------------------------
/paper/README.md:
--------------------------------------------------------------------------------
1 | # Images and data for publication
2 |
3 | This directory contains figures and tables output by the [PyR0
4 | model](https://www.medrxiv.org/content/10.1101/2021.09.07.21263228v1). These
5 | outputs are aggregated to weeks, PANGO lineages, and amino acid changes.
6 |
7 | Figures and tables are generated by first running preprocessing and inference,
8 | then postprocessing with the following Jupyter notebooks:
9 | [`mutrans.ipynb`](../notebooks/mutrans.ipynb),
10 | [`mutrans_gene.ipynb`](../notebooks/mutrans_gene.ipynb),
11 | [`mutrans_prediction.ipynb`](../notebooks/mutrans_prediction.ipynb),
12 | [`mutrans_backtesting.ipynb`](../notebooks/mutrans_backtesting.ipynb).
13 |
14 | ## Data tables
15 |
16 | - [Mutation table](mutations.tsv) is ranked by statistical significance.
17 | The "mean" field denotes the estimated effect on log growth rate of each mutation.
18 | - [Lineage table](strains.tsv) is ranked by growth rate.
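Both tables can be loaded directly with pandas; a minimal sketch (only the
"mean" column is documented above, so treat other column names as unverified):

```python
import pandas as pd

# Both tables are tab-separated.
mutations = pd.read_csv("mutations.tsv", sep="\t")
strains = pd.read_csv("strains.tsv", sep="\t")
print(mutations.head())  # rows ordered by statistical significance
print(strains.head())    # rows ordered by estimated growth rate
```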
19 |
20 | ## Manhattan plots
21 |
22 | 
23 | 
24 | 
25 | 
26 | 
27 | 
28 |
29 | ## Information density plots
30 |
31 | 
32 | 
33 |
34 | ## Volcano plot
35 |
36 | 
37 |
38 | ## Strain characterization plots
39 |
40 | 
41 | 
42 | 
43 | 
44 | 
45 |
46 | ## Cross validation plots
47 |
48 | The following plots assess robustness via 2-fold cross-validation, splitting the data into Europe versus the rest of the world.
49 |
50 | 
51 | 
52 | 
53 | 
54 | 
55 | 
56 |
57 | ## Data plots
58 |
59 | 
60 | 
61 |
62 | ## Acknowledgements
63 |
64 | The aggregated model outputs in this directory were generated from data inputs
65 | including GISAID records (https://gisaid.org), an UShER tree placement
66 | of those records
67 | (http://hgdownload.soe.ucsc.edu/goldenPath/wuhCor1/UShER_SARS-CoV-2), PANGO
68 | lineage classifications (https://cov-lineages.org), and case count time series
69 | from Johns Hopkins University (https://github.com/CSSEGISandData/COVID-19).
70 | Results in this directory can alternatively be generated using GenBank records
71 | (https://www.ncbi.nlm.nih.gov) instead of GISAID records.
72 |
73 | We gratefully acknowledge all data contributors, i.e. the Authors and their Originating laboratories responsible for obtaining the specimens, and their Submitting laboratories for generating the genetic sequence and metadata and sharing via the GISAID initiative [1,2] on which this research is based. A total of 6,466,299 submissions are included in this study. A complete list of the 6,466,299 accession numbers is available in [accession_ids.txt.xz](accession_ids.txt.xz).
74 |
75 | 1. GISAID Initiative and global contributors,
76 | EpiCoV(TM) human coronavirus 2019 database.
77 | GISAID (2020), (available at https://gisaid.org).
78 | 2. S. Elbe, G. Buckland-Merrett,
79 | Data, disease and diplomacy: GISAID's innovative contribution to global health.
80 | Glob Chall. 1, 33-46 (2017).
81 | 3. National Center for Biotechnology Information (NCBI)[Internet].
82 | Bethesda (MD): National Library of Medicine (US),
83 | National Center for Biotechnology Information;
84 | [1988] – [cited 2017 Apr 06].
85 | Available from: https://www.ncbi.nlm.nih.gov
86 |
--------------------------------------------------------------------------------
/paper/accession_ids.txt.xz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/accession_ids.txt.xz
--------------------------------------------------------------------------------
/paper/backtesting/L1_error_barplot_England.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/L1_error_barplot_England.png
--------------------------------------------------------------------------------
/paper/backtesting/L1_error_barplot_all.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/L1_error_barplot_all.png
--------------------------------------------------------------------------------
/paper/backtesting/L1_error_barplot_other.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/L1_error_barplot_other.png
--------------------------------------------------------------------------------
/paper/backtesting/L1_error_barplot_top100-1000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/L1_error_barplot_top100-1000.png
--------------------------------------------------------------------------------
/paper/backtesting/L1_error_barplot_top100.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/L1_error_barplot_top100.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_248_early_prediction_england.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_248_early_prediction_england.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_346_early_prediction_england.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_346_early_prediction_england.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_360_early_prediction_england.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_360_early_prediction_england.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_430.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_430.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_444.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_444.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_458.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_458.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_472.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_472.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_486.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_486.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_500.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_500.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_514.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_514.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_514_early_prediction_england.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_514_early_prediction_england.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_528.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_528.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_528_early_prediction_england.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_528_early_prediction_england.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_542.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_542.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_556.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_556.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_570.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_570.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_584.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_584.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_598.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_598.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_612.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_612.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_612_early_prediction_england.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_612_early_prediction_england.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_626.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_626.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_640.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_640.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_640_early_prediction_england.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_640_early_prediction_england.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_654.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_654.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_668.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_668.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_682.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_682.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_696.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_696.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_710.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_710.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_710_early_prediction_england.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_710_early_prediction_england.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_724.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_724.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_724_early_prediction_england.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_724_early_prediction_england.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_738.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_738.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_738_early_prediction_england.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_738_early_prediction_england.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_752.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_752.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_752_early_prediction_england.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_752_early_prediction_england.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_766.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_766.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_766_early_prediction_Asia_Europe_Africa.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_766_early_prediction_Asia_Europe_Africa.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_766_early_prediction_Massachusetts_for_heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_766_early_prediction_Massachusetts_for_heatmap.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_766_early_prediction_Massachusetts_for_heatmap_nolegend.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_766_early_prediction_Massachusetts_for_heatmap_nolegend.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_766_early_prediction_USA_France_England_Brazil_Australia_Russia.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_766_early_prediction_USA_France_England_Brazil_Australia_Russia.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_766_early_prediction_brazil_for_heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_766_early_prediction_brazil_for_heatmap.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_766_early_prediction_brazil_for_heatmap_nolegend.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_766_early_prediction_brazil_for_heatmap_nolegend.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_766_early_prediction_england.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_766_early_prediction_england.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_766_early_prediction_england_for_heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_766_early_prediction_england_for_heatmap.png
--------------------------------------------------------------------------------
/paper/backtesting/backtesting_day_766_early_prediction_england_for_heatmap_nolegend.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/backtesting/backtesting_day_766_early_prediction_england_for_heatmap_nolegend.png
--------------------------------------------------------------------------------
/paper/binding_retained.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/binding_retained.png
--------------------------------------------------------------------------------
/paper/clade_distribution.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/clade_distribution.png
--------------------------------------------------------------------------------
/paper/coef_scale_manhattan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/coef_scale_manhattan.png
--------------------------------------------------------------------------------
/paper/coef_scale_pearson.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/coef_scale_pearson.png
--------------------------------------------------------------------------------
/paper/coef_scale_scatter.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/coef_scale_scatter.png
--------------------------------------------------------------------------------
/paper/coef_scale_volcano.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/coef_scale_volcano.png
--------------------------------------------------------------------------------
/paper/convergence.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/convergence.png
--------------------------------------------------------------------------------
/paper/convergence_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/convergence_loss.png
--------------------------------------------------------------------------------
/paper/convergent_evolution.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/convergent_evolution.jpg
--------------------------------------------------------------------------------
/paper/deep_scanning.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/deep_scanning.png
--------------------------------------------------------------------------------
/paper/delta_logR_breakdown.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/delta_logR_breakdown.png
--------------------------------------------------------------------------------
/paper/dobanno_2021_immunogenicity_N.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/dobanno_2021_immunogenicity_N.png
--------------------------------------------------------------------------------
/paper/forecast.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/forecast.png
--------------------------------------------------------------------------------
/paper/formatted/Figure1.ai:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/formatted/Figure1.ai
--------------------------------------------------------------------------------
/paper/formatted/Figure1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/formatted/Figure1.jpg
--------------------------------------------------------------------------------
/paper/formatted/Figure2.ai:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/formatted/Figure2.ai
--------------------------------------------------------------------------------
/paper/formatted/Figure2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/formatted/Figure2.jpg
--------------------------------------------------------------------------------
/paper/formatted/Figure2_alt.ai:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/formatted/Figure2_alt.ai
--------------------------------------------------------------------------------
/paper/formatted/Figure2_alt.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/formatted/Figure2_alt.jpg
--------------------------------------------------------------------------------
/paper/gene_ratios.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/gene_ratios.png
--------------------------------------------------------------------------------
/paper/ind_emergences_violin.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/ind_emergences_violin.png
--------------------------------------------------------------------------------
/paper/lineage_agreement.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/lineage_agreement.png
--------------------------------------------------------------------------------
/paper/lineage_heterogeneity.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/lineage_heterogeneity.png
--------------------------------------------------------------------------------
/paper/lineage_prediction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/lineage_prediction.png
--------------------------------------------------------------------------------
/paper/main.bib:
--------------------------------------------------------------------------------
1 | @article{bingham2019pyro,
2 | title={Pyro: Deep universal probabilistic programming},
3 | author={Bingham, Eli and Chen, Jonathan P and Jankowiak, Martin and Obermeyer, Fritz and Pradhan, Neeraj and Karaletsos, Theofanis and Singh, Rohit and Szerlip, Paul and Horsfall, Paul and Goodman, Noah D},
4 | journal={The Journal of Machine Learning Research},
5 | volume={20},
6 | number={1},
7 | pages={973--978},
8 | year={2019},
9 | publisher={JMLR.org},
10 | }
11 |
12 | @misc{aksamentov2020nextclade,
13 | title={Nextclade},
14 | author={Aksamentov, Ivan and Neher, Richard},
15 | url={https://github.com/nextstrain/nextclade},
16 | }
17 |
18 | @inproceedings{gorinova2020automatic,
19 | title={Automatic reparameterisation of probabilistic programs},
20 | author={Gorinova, Maria and Moore, Dave and Hoffman, Matthew},
21 | booktitle={International Conference on Machine Learning},
22 | pages={3648--3657},
23 | year={2020},
24 | organization={PMLR},
25 | }
26 |
27 | @inproceedings{paszke2017automatic,
28 | title={Automatic differentiation in PyTorch},
29 | author={Paszke, Adam and Gross, Sam and Chintala, Soumith and Chanan, Gregory and Yang, Edward and DeVito, Zachary and Lin, Zeming and Desmaison, Alban and Antiga, Luca and Lerer, Adam},
30 | booktitle={NIPS-W},
31 | year={2017},
32 | }
33 |
34 | @article{elbe2017gisaid,
35 | author={Elbe, S. and Buckland-Merrett, G.},
36 | year={2017},
37 | title={Data, disease and diplomacy: GISAID’s innovative contribution to global health},
38 | journal={Global Challenges},
39 | volume={1},
40 | pages={33--46},
41 | doi={10.1002/gch2.1018},
42 | pmid={31565258},
43 | }
44 |
45 | @article{rambaut2020dynamic,
46 | title={A dynamic nomenclature proposal for SARS-CoV-2 lineages to assist genomic epidemiology},
47 | author={Rambaut, A. and Holmes, E.C. and O’Toole, Á and Hill, V. and McCrone, J. T. and Ruis, C. and du Plessis, L. and Pybus, O. G.},
48 | year={2020},
49 | journal={Nature Microbiology},
50 | doi={10.1038/s41564-020-0770-5},
51 | }
52 |
53 | @article{neal2003slice,
54 | author={Neal, Radford M.},
55 | year={2003},
56 | title={Slice sampling},
57 | journal={Annals of Statistics},
58 | volume={31},
59 | number={3},
60 | pages={705--767},
61 | }
62 |
--------------------------------------------------------------------------------
/paper/main.md:
--------------------------------------------------------------------------------
1 | # Analysis of 1,000,000 virus genomes reveals determinants of relative growth rate
2 |
3 | Fritz Obermeyer, Jacob Lemieux, Stephen Schaffner, Daniel Park
4 |
5 | ## Abstract
6 |
7 | We fit a large logistic growth regression model to sequenced genomes of SARS-CoV-2 samples distributed in space (globally) and time (late 2019 -- early 2021).
8 | This model attributes relative changes in transmissibility to each observed mutation, aggregated over >800 PANGO lineages.
9 | Results show that one region of the N gene is particularly important in determining transmissibility, followed by the S gene and the ORF3a gene.
10 |
11 | ## Discussion
12 |
13 | Note that many of the most transmissibility-influencing mutations occur in the N gene, particularly at positions 28800--29000 (see Figure TODO).
14 |
15 | ## Materials and methods
16 |
17 | ### Model
18 |
19 | We model the strain distribution over time as a softmax of a set of linear growth functions, one per strain; this is the multivariate generalization of the standard logistic growth model.
20 | Further, we factorize the growth rate of each strain as a sum over growth rates of the mutations in that strain:
21 | ```
22 | strain_portion = softmax(strain_init + mutation_rate @ mutations_per_strain)
23 | ```
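To make the shapes concrete, here is a minimal illustrative sketch of this parameterization (toy dimensions; the names mirror the formula above, but this is not the repository's actual model code):

```python
import torch

T, R, S, M = 56, 10, 100, 500  # weeks, regions, strains, mutations (toy sizes)
mutations_per_strain = torch.randint(0, 2, (M, S)).float()  # binary feature matrix
mutation_rate = 0.01 * torch.randn(M)  # per-mutation growth rate coefficients
strain_init = torch.randn(R, S)        # initial log prevalence per (region, strain)
time = torch.arange(T).float()

# Each strain's growth rate is the sum of its mutations' coefficients.
strain_rate = mutation_rate @ mutations_per_strain        # (S,)
logits = strain_init + strain_rate * time[:, None, None]  # (T, R, S)
strain_portion = logits.softmax(dim=-1)  # distribution over strains per (week, region)
```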
24 | Our hierarchical probabilistic model includes latent variables for:
25 | - global feature regularization parameter (how sparse are the features?).
26 | - initial strain prevalence in each (region, strain).
27 | - rate coefficients of each mutation, with a Laplace prior.
28 | - global observation concentration (how noisy are the observations?).
29 |
30 | The model uses a Dirichlet-multinomial likelihood (the multivariate
31 | generalization of the beta-binomial) with a learned, shared concentration parameter.
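For a single (week, region) cell, this likelihood looks as follows; a hedged sketch using Pyro's `DirichletMultinomial` with illustrative numbers:

```python
import torch
import pyro.distributions as dist

strain_portion = torch.tensor([0.7, 0.2, 0.1])  # modeled portions of 3 strains
concentration = torch.tensor(10.0)              # shared concentration (a latent variable in the model)
counts = torch.tensor([68.0, 22.0, 10.0])       # observed weekly sequence counts, n = 100

likelihood = dist.DirichletMultinomial(concentration * strain_portion, total_count=100)
print(likelihood.log_prob(counts))  # smaller concentration => more overdispersion
```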
32 |
33 | TODO add graphical model figure
34 |
35 | This model is robust to a number of sources of bias:
36 | - Sampling bias across regions (it's fine for one region to sample 100x more than another)
37 | - Sampling bias over time (it's fine to change sampling rate over time)
38 | - Change in absolute growth rate of all strains, in any (region, time) cell (i.e. the model is robust to changes in local policies or weather, as long as those changes equally affect all strains).
39 |
40 | However, the model is susceptible to the following sources of bias:
41 | - Biased sampling in any (region, time) cell (e.g. sequencing only in case of S-gene target failure).
42 | - Changes in sampling bias within a single region over time (e.g. a country has a lab in only one city, then spins up a second lab in another distant city with different strain portions).
43 |
44 | TODO We considered the possibility of biased submission to the GISAID database, and analyzed the CDC NS3 dataset, finding similar results.
45 |
46 | ### Inference
47 |
48 | We use a combination of MAP estimation and stochastic variational inference.
49 | We MAP estimate the global parameters and initial strain prevalences.
50 | For the remaining mutation-transmissibility latent variables we would like to estimate statistical significance, so we fit a mean-field variational distribution: independent normal distributions with learned means and variances.
51 | While the mean field assumption leads to underdispersion on an absolute scale, it nevertheless allows us to rank the _relative statistical significance_ of different mutations.
52 | This relative ranking is sufficient for the purposes of this paper: to focus attention on mutations and regions of the SARS-CoV-2 genome that appear to be linked to increased transmissibility.
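The following sketch shows one way to set this up in Pyro, assuming a `model` function whose mutation-coefficient sample site is named `"coef"` (both names are illustrative, not the repository's actual code):

```python
import pyro
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDelta, AutoGuideList, AutoNormal
from pyro.optim import ClippedAdam

# `model` is assumed to be defined elsewhere.
guide = AutoGuideList(model)
# MAP estimates for global parameters and initial strain prevalences.
guide.append(AutoDelta(poutine.block(model, hide=["coef"])))
# Mean-field normal posterior over mutation coefficients, yielding the
# per-mutation means and variances used to rank relative significance.
guide.append(AutoNormal(poutine.block(model, expose=["coef"])))

svi = SVI(model, guide, ClippedAdam({"lr": 0.05}), Trace_ELBO())
for step in range(1001):
    loss = svi.step()
```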
53 |
54 | Inference is performed via the Pyro probabilistic programming language.
55 | Source code is available at [github.com/broadinstitute/pyro-cov](https://github.com/broadinstitute/pyro-cov).
56 |
57 | ## References
58 |
59 | - Eli Bingham, Jonathan P. Chen, Martin Jankowiak, Fritz Obermeyer, Neeraj Pradhan,
60 | Theofanis Karaletsos, Rohit Singh, Paul Szerlip, Paul Horsfall, Noah D. Goodman
61 | (2018)
62 | "Pyro: Deep Universal Probabilistic Programming"
63 | https://arxiv.org/abs/1810.09538
64 |
65 | ## Tables
66 |
67 | 1. [Top Mutations](top_mutations.md)
68 | 2. [Top Strains](top_strains.md)
69 |
70 | ## Figures
71 |
72 | [](manhattan.pdf)
73 |
74 | [](volcano.pdf)
75 |
--------------------------------------------------------------------------------
/paper/manhattan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/manhattan.png
--------------------------------------------------------------------------------
/paper/manhattan_N.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/manhattan_N.png
--------------------------------------------------------------------------------
/paper/manhattan_N_coef_scale_0.05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/manhattan_N_coef_scale_0.05.png
--------------------------------------------------------------------------------
/paper/manhattan_ORF1a.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/manhattan_ORF1a.png
--------------------------------------------------------------------------------
/paper/manhattan_ORF1a_coef_scale_0.05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/manhattan_ORF1a_coef_scale_0.05.png
--------------------------------------------------------------------------------
/paper/manhattan_ORF1b.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/manhattan_ORF1b.png
--------------------------------------------------------------------------------
/paper/manhattan_ORF1b_coef_scale_0.05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/manhattan_ORF1b_coef_scale_0.05.png
--------------------------------------------------------------------------------
/paper/manhattan_ORF3a.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/manhattan_ORF3a.png
--------------------------------------------------------------------------------
/paper/manhattan_ORF3a_coef_scale_0.05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/manhattan_ORF3a_coef_scale_0.05.png
--------------------------------------------------------------------------------
/paper/manhattan_S.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/manhattan_S.png
--------------------------------------------------------------------------------
/paper/manhattan_S_coef_scale_0.05.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/manhattan_S_coef_scale_0.05.png
--------------------------------------------------------------------------------
/paper/moran.csv:
--------------------------------------------------------------------------------
1 | Gene,NumMutations,GeneSize,PValue,Lengthscale
2 | EntireGenome,2904.0,29394.0,1e-06,100.0
3 | EntireGenome,2904.0,29394.0,1e-06,500.0
4 | S,415.0,3786.0,0.00191,50.0
5 | N,220.0,1251.0,0.017627,50.0
6 | ORF7a,75.0,360.0,0.024066,18.0
7 | ORF3a,198.0,789.0,0.024307,39.45
8 | ORF1a,1107.0,13182.0,0.02971,50.0
9 | ORF7b,26.0,126.0,0.089589,6.3
10 | ORF14,69.0,213.0,0.112527,10.65
11 | ORF6,19.0,177.0,0.138634,8.85
12 | ORF1b,552.0,8052.0,0.329416,50.0
13 | E,17.0,195.0,0.455606,9.75
14 | M,42.0,639.0,0.518497,31.95
15 | ORF8,91.0,357.0,0.810603,17.85
16 | ORF10,18.0,102.0,0.864965,5.1
17 | ORF9b,55.0,240.0,0.927562,12.0
18 |
--------------------------------------------------------------------------------
/paper/multinomial_LR_vs_pyr0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/multinomial_LR_vs_pyr0.jpg
--------------------------------------------------------------------------------
/paper/mutation_agreement.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_agreement.png
--------------------------------------------------------------------------------
/paper/mutation_europe_boxplot_rankby_s.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_europe_boxplot_rankby_s.png
--------------------------------------------------------------------------------
/paper/mutation_europe_boxplot_rankby_t.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_europe_boxplot_rankby_t.png
--------------------------------------------------------------------------------
/paper/mutation_scoring/by_gene_heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/by_gene_heatmap.png
--------------------------------------------------------------------------------
/paper/mutation_scoring/plot_individual_mutation_significance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_individual_mutation_significance.png
--------------------------------------------------------------------------------
/paper/mutation_scoring/plot_individual_mutation_significance_M.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_individual_mutation_significance_M.png
--------------------------------------------------------------------------------
/paper/mutation_scoring/plot_individual_mutation_significance_S.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_individual_mutation_significance_S.png
--------------------------------------------------------------------------------
/paper/mutation_scoring/plot_individual_mutation_significance_S__K_to_N.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_individual_mutation_significance_S__K_to_N.png
--------------------------------------------------------------------------------
/paper/mutation_scoring/plot_k_to_n_mutation_significance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_k_to_n_mutation_significance.png
--------------------------------------------------------------------------------
/paper/mutation_scoring/plot_per_gene_aggregate_sign.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_per_gene_aggregate_sign.png
--------------------------------------------------------------------------------
/paper/mutation_scoring/plot_pval_vs_top_genes_E.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_pval_vs_top_genes_E.png
--------------------------------------------------------------------------------
/paper/mutation_scoring/plot_pval_vs_top_genes_N.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_pval_vs_top_genes_N.png
--------------------------------------------------------------------------------
/paper/mutation_scoring/plot_pval_vs_top_genes_ORF10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_pval_vs_top_genes_ORF10.png
--------------------------------------------------------------------------------
/paper/mutation_scoring/plot_pval_vs_top_genes_ORF1a.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_pval_vs_top_genes_ORF1a.png
--------------------------------------------------------------------------------
/paper/mutation_scoring/plot_pval_vs_top_genes_ORF1b.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_pval_vs_top_genes_ORF1b.png
--------------------------------------------------------------------------------
/paper/mutation_scoring/plot_pval_vs_top_genes_ORF3a.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_pval_vs_top_genes_ORF3a.png
--------------------------------------------------------------------------------
/paper/mutation_scoring/plot_pval_vs_top_genes_ORF6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_pval_vs_top_genes_ORF6.png
--------------------------------------------------------------------------------
/paper/mutation_scoring/plot_pval_vs_top_genes_ORF7b.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_pval_vs_top_genes_ORF7b.png
--------------------------------------------------------------------------------
/paper/mutation_scoring/plot_pval_vs_top_genes_ORF8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_pval_vs_top_genes_ORF8.png
--------------------------------------------------------------------------------
/paper/mutation_scoring/plot_pval_vs_top_genes_ORF9b.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_pval_vs_top_genes_ORF9b.png
--------------------------------------------------------------------------------
/paper/mutation_scoring/plot_pval_vs_top_genes_S.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/plot_pval_vs_top_genes_S.png
--------------------------------------------------------------------------------
/paper/mutation_scoring/pvals_vs_top_genes.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/pvals_vs_top_genes.png
--------------------------------------------------------------------------------
/paper/mutation_scoring/top_substitutions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_scoring/top_substitutions.png
--------------------------------------------------------------------------------
/paper/mutation_summaries.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/mutation_summaries.jpg
--------------------------------------------------------------------------------
/paper/region_distribution.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/region_distribution.png
--------------------------------------------------------------------------------
/paper/relative_growth_rate.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/relative_growth_rate.jpg
--------------------------------------------------------------------------------
/paper/schematic_overview.key:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/schematic_overview.key
--------------------------------------------------------------------------------
/paper/schematic_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/schematic_overview.png
--------------------------------------------------------------------------------
/paper/scibib.bib:
--------------------------------------------------------------------------------
1 | % scibib.bib
2 |
3 | % This is the .bib file used to compile the document "A simple Science
4 | % template" (scifile.tex). It is not intended as an example of how to
5 | % set up your BibTeX file.
6 |
7 |
8 |
9 |
10 | @misc{tth, note = "The package is TTH, available at
11 | http://hutchinson.belmont.ma.us/tth/ ."}
12 |
13 | @misc{use2e, note = "As the mark-up of the \TeX\ source for this
14 | document makes clear, your file should be coded in \LaTeX
15 | 2${\varepsilon}$, not \LaTeX\ 2.09 or an earlier release. Also,
16 | please use the \texttt{article} document class."}
17 |
18 | @misc{inclme, note="Among whom are the author of this document. The
19 | ``real'' references and notes contained herein were compiled using
20 | B{\small{IB}}\TeX\ from the sample .bib file \texttt{scibib.bib}, the style
21 | package \texttt{scicite.sty}, and the bibliography style file
22 | \texttt{Science.bst}."}
23 |
24 |
25 | @misc{nattex, note="One of the equation editors we use, Equation Magic
26 | (MicroPress Inc., Forest Hills, NY; http://www.micropress-inc.com/),
27 | interprets native \TeX\ source code and generates an equation as an
28 | OLE picture object that can then be cut and pasted directly into Word.
29 | This editor, however, does not handle \LaTeX\ environments (such as
30 | \texttt{\{array\}} or \texttt{\{eqnarray\}}); it can interpret only
31 | \TeX\ codes. Thus, when there's a choice, we ask that you avoid these
32 | \LaTeX\ calls in displayed math --- for example, that you use the
33 | \TeX\ \verb+\matrix+ command for ordinary matrices, rather than the
34 | \LaTeX\ \texttt{\{array\}} environment."}
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
--------------------------------------------------------------------------------
/paper/spectrum_transmissibility.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/spectrum_transmissibility.jpg
--------------------------------------------------------------------------------
/paper/strain_emergence.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/strain_emergence.png
--------------------------------------------------------------------------------
/paper/strain_europe_boxplot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/strain_europe_boxplot.png
--------------------------------------------------------------------------------
/paper/strain_prevalence.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/strain_prevalence.png
--------------------------------------------------------------------------------
/paper/supplement_table_3_02_05_2022.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/supplement_table_3_02_05_2022.jpeg
--------------------------------------------------------------------------------
/paper/table2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/table2.jpeg
--------------------------------------------------------------------------------
/paper/table_1.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/table_1.jpeg
--------------------------------------------------------------------------------
/paper/table_1.tsv:
--------------------------------------------------------------------------------
1 | Rank Gene Substitution "Fold Increase in Fitness" "Number of Lineages"
2 | 0 1 S R346K 1.098 29
3 | 1 2 S L452Q 1.092 9
4 | 2 3 S V213G 1.057 158
5 | 3 4 S S704L 1.080 9
6 | 4 5 S T19I 1.053 158
7 | 5 6 ORF1b R1315C 1.047 159
8 | 6 7 ORF1b T2163I 1.056 159
9 | 7 8 ORF3a T223I 1.052 178
10 | 8 9 ORF1a T3090I 1.042 159
11 | 9 10 M D3N 1.127 76
12 | 10 11 ORF10 L37F 1.067 22
13 | 11 12 N S413R 1.049 159
14 | 12 13 ORF9b D16G 1.091 36
15 | 13 14 S T376A 1.040 158
16 | 14 15 ORF1a L3027F 1.043 159
17 | 15 16 ORF1a T842I 1.036 169
18 | 16 17 S D796Y 1.063 174
19 | 17 18 S L452R 1.177 367
20 | 18 19 S S371F 1.039 158
21 | 19 20 S Q954H 1.066 211
22 |
--------------------------------------------------------------------------------
/paper/table_1_02_05_2022.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/table_1_02_05_2022.jpeg
--------------------------------------------------------------------------------
/paper/table_2_02_05_2022.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/table_2_02_05_2022.jpeg
--------------------------------------------------------------------------------
/paper/vary_gene_elbo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/vary_gene_elbo.png
--------------------------------------------------------------------------------
/paper/vary_gene_likelihood.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/vary_gene_likelihood.png
--------------------------------------------------------------------------------
/paper/vary_gene_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/vary_gene_loss.png
--------------------------------------------------------------------------------
/paper/vary_nsp_elbo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/vary_nsp_elbo.png
--------------------------------------------------------------------------------
/paper/vary_nsp_likelihood.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/vary_nsp_likelihood.png
--------------------------------------------------------------------------------
/paper/volcano.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/paper/volcano.png
--------------------------------------------------------------------------------
/pyrocov/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | __version__ = "0.1.0"
5 |
--------------------------------------------------------------------------------
/pyrocov/aa.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | DNA_TO_AA = {
5 | "TTT": "F",
6 | "TTC": "F",
7 | "TTA": "L",
8 | "TTG": "L",
9 | "CTT": "L",
10 | "CTC": "L",
11 | "CTA": "L",
12 | "CTG": "L",
13 | "ATT": "I",
14 | "ATC": "I",
15 | "ATA": "I",
16 | "ATG": "M",
17 | "GTT": "V",
18 | "GTC": "V",
19 | "GTA": "V",
20 | "GTG": "V",
21 | "TCT": "S",
22 | "TCC": "S",
23 | "TCA": "S",
24 | "TCG": "S",
25 | "CCT": "P",
26 | "CCC": "P",
27 | "CCA": "P",
28 | "CCG": "P",
29 | "ACT": "T",
30 | "ACC": "T",
31 | "ACA": "T",
32 | "ACG": "T",
33 | "GCT": "A",
34 | "GCC": "A",
35 | "GCA": "A",
36 | "GCG": "A",
37 | "TAT": "Y",
38 | "TAC": "Y",
39 | "TAA": None, # stop
40 | "TAG": None, # stop
41 | "CAT": "H",
42 | "CAC": "H",
43 | "CAA": "Q",
44 | "CAG": "Q",
45 | "AAT": "N",
46 | "AAC": "N",
47 | "AAA": "K",
48 | "AAG": "K",
49 | "GAT": "D",
50 | "GAC": "D",
51 | "GAA": "E",
52 | "GAG": "E",
53 | "TGT": "C",
54 | "TGC": "C",
55 | "TGA": None, # stop
56 | "TGG": "W",
57 | "CGT": "R",
58 | "CGC": "R",
59 | "CGA": "R",
60 | "CGG": "R",
61 | "AGT": "S",
62 | "AGC": "S",
63 | "AGA": "R",
64 | "AGG": "R",
65 | "GGT": "G",
66 | "GGC": "G",
67 | "GGA": "G",
68 | "GGG": "G",
69 | }
70 |
--------------------------------------------------------------------------------
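[Editor's note — usage sketch for /pyrocov/aa.py, not a repository file]
DNA_TO_AA maps each codon to a one-letter amino acid, with None marking the
three stop codons. A minimal translation helper built on it, assuming an
in-frame sequence:

    from pyrocov.aa import DNA_TO_AA

    def translate(seq: str) -> str:
        """Translate an in-frame DNA sequence; '*' marks a stop codon."""
        aas = []
        for i in range(0, len(seq) - len(seq) % 3, 3):
            aa = DNA_TO_AA[seq[i : i + 3]]
            aas.append("*" if aa is None else aa)
        return "".join(aas)

    assert translate("ATGGGTTAA") == "MG*"  # Met, Gly, stop

--------------------------------------------------------------------------------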
/pyrocov/align.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import os
5 |
6 | # Source: https://samtools.github.io/hts-specs/SAMv1.pdf
7 | CIGAR_CODES = "MIDNSHP=X" # Note minimap2 uses only "MIDNSH"
8 |
9 | ROOT = os.path.dirname(os.path.dirname(__file__))
10 | NEXTCLADE_DATA = os.path.expanduser("~/github/nextstrain/nextclade/data/sars-cov-2")
11 | PANGOLEARN_DATA = os.path.expanduser("~/github/cov-lineages/pangoLEARN/pangoLEARN/data")
12 |
--------------------------------------------------------------------------------
/pyrocov/distributions.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import math
5 |
6 | import torch
7 | from pyro.distributions import TorchDistribution
8 | from torch.distributions import constraints
9 | from torch.distributions.utils import broadcast_all
10 |
11 |
12 | class SoftLaplace(TorchDistribution):
13 | """
14 | Smooth distribution with Laplace-like tail behavior.
15 | """
16 |
17 | arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
18 | support = constraints.real
19 | has_rsample = True
20 |
21 | def __init__(self, loc, scale, *, validate_args=None):
22 | self.loc, self.scale = broadcast_all(loc, scale)
23 | super().__init__(self.loc.shape, validate_args=validate_args)
24 |
25 | def expand(self, batch_shape, _instance=None):
26 | new = self._get_checked_instance(SoftLaplace, _instance)
27 | batch_shape = torch.Size(batch_shape)
28 | new.loc = self.loc.expand(batch_shape)
29 | new.scale = self.scale.expand(batch_shape)
30 | super(SoftLaplace, new).__init__(batch_shape, validate_args=False)
31 | new._validate_args = self._validate_args
32 | return new
33 |
34 | def log_prob(self, value):
35 | if self._validate_args:
36 | self._validate_sample(value)
37 | z = (value - self.loc) / self.scale
38 | return math.log(2 / math.pi) - self.scale.log() - torch.logaddexp(z, -z)
39 |
40 | def rsample(self, sample_shape=torch.Size()):
41 | shape = self._extended_shape(sample_shape)
42 | u = self.loc.new_empty(shape).uniform_()
43 | return self.icdf(u)
44 |
45 | def cdf(self, value):
46 | if self._validate_args:
47 | self._validate_sample(value)
48 | z = (value - self.loc) / self.scale
49 | return z.exp().atan().mul(2 / math.pi)
50 |
51 | def icdf(self, value):
52 | return value.mul(math.pi / 2).tan().log().mul(self.scale).add(self.loc)
53 |
--------------------------------------------------------------------------------
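[Editor's note — usage sketch for /pyrocov/distributions.py, not a repository file]
SoftLaplace samples by inverse-CDF (rsample applies icdf to uniform noise), so
cdf and icdf should be inverses of each other up to floating point error. A
quick round-trip check, assuming pyrocov is importable:

    import torch
    from pyrocov.distributions import SoftLaplace

    d = SoftLaplace(torch.tensor(0.0), torch.tensor(1.0))
    x = torch.linspace(-5.0, 5.0, 11)
    assert torch.allclose(d.icdf(d.cdf(x)), x, atol=1e-4)

    samples = d.rsample(torch.Size([1000]))  # reparameterized draws

--------------------------------------------------------------------------------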
/pyrocov/external/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/pyrocov/external/__init__.py
--------------------------------------------------------------------------------
/pyrocov/external/usher/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Yatish Turakhia (https://turakhia.eng.ucsd.edu/), Haussler Lab (https://hausslergenomics.ucsc.edu/), Corbett-Detig Lab (https://corbett.ucsc.edu/)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/pyrocov/external/usher/README.md:
--------------------------------------------------------------------------------
1 | This directory contains code copied from https://github.com/yatisht/usher
2 |
--------------------------------------------------------------------------------
/pyrocov/external/usher/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/broadinstitute/pyro-cov/9f84acc3ddff9bcb55c8d1b77fd23204a4f54b8e/pyrocov/external/usher/__init__.py
--------------------------------------------------------------------------------
/pyrocov/fasta.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 |
5 | class ShardedFastaWriter:
6 | """
7 | Writer that splits output into multiple fasta files to avoid nextclade's
8 | file size limit.
9 | """
10 |
11 | def __init__(self, filepattern, max_count=5000):
12 | assert filepattern.count("*") == 1
13 | self.filepattern = filepattern
14 | self.max_count = max_count
15 | self._file_count = 0
16 | self._line_count = 0
17 | self._file = None
18 |
19 | def _open(self):
20 | filename = self.filepattern.replace("*", str(self._file_count))
21 | print(f"writing to {filename}")
22 | return open(filename, "wt")
23 |
24 | def __enter__(self):
25 | assert self._file is None
26 | self._file = self._open()
27 | self._file_count += 1
28 | return self
29 |
30 | def __exit__(self, *args, **kwargs):
31 | self._file.close()
32 | self._file = None
33 | self._file_count = 0
34 | self._line_count = 0
35 |
36 | def write(self, name, sequence):
37 | if self._line_count == self.max_count:
38 | self._file.close()
39 | self._file = self._open()
40 | self._file_count += 1
41 | self._line_count = 0
42 | self._file.write(">")
43 | self._file.write(name)
44 | self._file.write("\n")
45 | self._file.write(sequence)
46 | self._file.write("\n")
47 | self._line_count += 1
48 |
--------------------------------------------------------------------------------
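[Editor's note — usage sketch for /pyrocov/fasta.py, not a repository file]
ShardedFastaWriter rolls over to a new file each time max_count records have
been written. With the hypothetical filename pattern below, 12000 records land
in shard.0.fasta, shard.1.fasta, and shard.2.fasta (5000/5000/2000):

    from pyrocov.fasta import ShardedFastaWriter

    with ShardedFastaWriter("shard.*.fasta", max_count=5000) as writer:
        for i in range(12000):
            writer.write(f"seq{i}", "ACGT" * 10)

--------------------------------------------------------------------------------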
/pyrocov/hashsubset.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import hashlib
5 | import heapq
6 |
7 |
8 | class RandomSubDict:
9 | def __init__(self, max_size):
10 | assert isinstance(max_size, int) and max_size > 0
11 | self.key_to_value = {}
12 | self.hash_to_key = {}
13 | self.heap = []
14 | self.max_size = max_size
15 |
16 | def __len__(self):
17 | return len(self.heap)
18 |
19 | def __setitem__(self, key, value):
20 | assert key not in self.key_to_value
21 |
22 | # Add (key,value) pair.
23 | self.key_to_value[key] = value
24 | h = hashlib.sha1(key.encode("utf-8")).hexdigest()
25 | self.hash_to_key[h] = key
26 | heapq.heappush(self.heap, h)
27 |
28 | # Truncate via k-min-hash.
29 | if len(self.heap) > self.max_size:
30 | h = heapq.heappop(self.heap)
31 | key = self.hash_to_key.pop(h)
32 | self.key_to_value.pop(key)
33 |
34 | def keys(self):
35 | assert len(self.key_to_value) <= self.max_size
36 | return self.key_to_value.keys()
37 |
38 | def values(self):
39 | assert len(self.key_to_value) <= self.max_size
40 | return self.key_to_value.values()
41 |
42 | def items(self):
43 | assert len(self.key_to_value) <= self.max_size
44 | return self.key_to_value.items()
45 |
--------------------------------------------------------------------------------
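[Editor's note — usage sketch for /pyrocov/hashsubset.py, not a repository file]
RandomSubDict keeps at most max_size entries, evicting the smallest SHA-1 hash
whenever it overflows, so the retained subset is deterministic across runs and
machines rather than dependent on insertion order or a random seed:

    from pyrocov.hashsubset import RandomSubDict

    subset = RandomSubDict(max_size=100)
    for i in range(1000):
        subset[f"accession{i}"] = f"sequence{i}"
    assert len(subset.keys()) == 100  # same 100 keys every run

--------------------------------------------------------------------------------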
/pyrocov/ops.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import weakref
5 | from typing import Dict
6 |
7 | import torch
8 |
9 |
10 | def logistic_logsumexp(alpha, beta, delta, tau, *, backend="sequential"):
11 | """
12 | Computes::
13 |
14 | (alpha + beta * (delta + tau[:, :, None])).logsumexp(-1)
15 |
16 | where::
17 |
18 | alpha.shape == [P, S]
19 | beta.shape == [P, S]
20 | delta.shape == [P, S]
21 | tau.shape == [T, P]
22 |
23 | :param str backend: One of "naive", "sequential".
24 | """
25 | assert alpha.dim() == 2
26 | assert alpha.shape == beta.shape == delta.shape
27 | assert tau.dim() == 2
28 | assert tau.size(1) == alpha.size(0)
29 | assert not tau.requires_grad
30 |
31 | if backend == "naive":
32 | return (alpha + beta * (delta + tau[:, :, None])).logsumexp(-1)
33 | if backend == "sequential":
34 | return LogisticLogsumexp.apply(alpha, beta, delta, tau)
35 | raise ValueError(f"Unknown backend: {repr(backend)}")
36 |
37 |
38 | class LogisticLogsumexp(torch.autograd.Function):
39 | @staticmethod
40 | def forward(ctx, alpha, beta, delta, tau):
41 | P, S = alpha.shape
42 | T, P = tau.shape
43 | output = alpha.new_zeros(T, P)
44 | for t in range(len(tau)):
45 | logits = (delta + tau[t, :, None]).mul_(beta).add_(alpha) # [P, S]
46 | output[t] = logits.logsumexp(-1) # [P]
47 |
48 | ctx.save_for_backward(alpha, beta, delta, tau, output)
49 | return output # [T, P]
50 |
51 | @staticmethod
52 | def backward(ctx, grad_output):
53 | alpha, beta, delta, tau, output = ctx.saved_tensors
54 |
55 | grad_alpha = torch.zeros_like(alpha) # [P, S]
56 | grad_beta = torch.zeros_like(beta) # [P, S]
57 | for t in range(len(tau)):
58 | delta_tau = delta + tau[t, :, None] # [P, S]
59 | logits = (delta_tau * beta).add_(alpha) # [P, S]
60 | softmax_logits = logits.sub_(output[t, :, None]).exp_() # [P, S]
61 | grad_logits = softmax_logits * grad_output[t, :, None] # [P, S]
62 | grad_alpha += grad_logits # [P, S]
63 | grad_beta += delta_tau * grad_logits # [P, S]
64 |
65 | grad_delta = beta * grad_alpha # [P, S]
66 | return grad_alpha, grad_beta, grad_delta, None
67 |
68 |
69 | _log_factorial_cache: Dict[int, torch.Tensor] = {}
70 |
71 |
72 | def log_factorial_sum(x: torch.Tensor) -> torch.Tensor:
73 | if x.requires_grad:
74 | return (x + 1).lgamma().sum()
75 | key = id(x)
76 | if key not in _log_factorial_cache:
77 | weakref.finalize(x, _log_factorial_cache.pop, key, None) # type: ignore
78 | _log_factorial_cache[key] = (x + 1).lgamma().sum()
79 | return _log_factorial_cache[key]
80 |
81 |
82 | def sparse_poisson_likelihood(full_log_rate, nonzero_log_rate, nonzero_value):
83 | """
84 | The following are equivalent::
85 |
86 | # Version 1. dense
87 | log_prob = Poisson(log_rate.exp()).log_prob(value).sum()
88 |
89 | # Version 2. sparse
90 | nnz = value.nonzero(as_tuple=True)
91 | log_prob = sparse_poisson_likelihood(
92 | log_rate.logsumexp(-1),
93 | log_rate[nnz],
94 | value[nnz],
95 | )
96 | """
97 | # Let p = Poisson(log_rate.exp()). Then
98 | # p.log_prob(value)
99 | # = log_rate * value - log_rate.exp() - (value + 1).lgamma()
100 | # p.log_prob(0) = -log_rate.exp()
101 | # p.log_prob(value) - p.log_prob(0)
102 | # = log_rate * value - log_rate.exp() - (value + 1).lgamma() + log_rate.exp()
103 | # = log_rate * value - (value + 1).lgamma()
104 | return (
105 | torch.dot(nonzero_log_rate, nonzero_value)
106 | - log_factorial_sum(nonzero_value)
107 | - full_log_rate.exp().sum()
108 | )
109 |
110 |
111 | def sparse_multinomial_likelihood(total_count, nonzero_logits, nonzero_value):
112 | """
113 | The following are equivalent::
114 |
115 | # Version 1. dense
116 | log_prob = Multinomial(logits=logits).log_prob(value).sum()
117 |
118 | # Version 2. sparse
119 | nnz = value.nonzero(as_tuple=True)
120 | log_prob = sparse_multinomial_likelihood(
121 | value.sum(-1),
122 | (logits - logits.logsumexp(-1))[nnz],
123 | value[nnz],
124 | )
125 | """
126 | return (
127 | log_factorial_sum(total_count)
128 | - log_factorial_sum(nonzero_value)
129 | + torch.dot(nonzero_logits, nonzero_value)
130 | )
131 |
132 |
133 | def sparse_categorical_kl(log_q, p_support, log_p):
134 | """
134 | Computes the restricted KL divergence::
136 |
137 | sum_i restrict(q)(i) (log q(i) - log p(i))
138 |
139 | where ``p`` is a uniform prior, ``q`` is the posterior, and
140 | ``restrict(q))`` is the posterior restricted to the support of ``p`` and
141 | renormalized. Note for degenerate ``p=delta(i)`` this reduces to the log
142 | likelihood ``log q(i)``.
143 | """
144 | assert log_q.dim() == 1
145 | assert log_p.dim() == 1
146 | assert p_support.shape == log_p.shape + log_q.shape
147 | q = log_q.exp()
148 | sum_q = torch.mv(p_support, q)
149 | sum_q_log_q = torch.mv(p_support, q * log_q)
150 | sum_r_log_q = sum_q_log_q / sum_q # restrict and normalize
151 | kl = sum_r_log_q - log_p # note sum_r_log_p = log_p because p is uniform
152 | return kl.sum()
153 |
--------------------------------------------------------------------------------
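[Editor's note — usage sketch for /pyrocov/ops.py, not a repository file]
Two quick equivalence checks, assuming pyrocov is importable: the
memory-efficient "sequential" backend of logistic_logsumexp should agree with
the "naive" broadcasting backend, and the sparse Poisson likelihood should
match the dense computation its docstring describes:

    import torch
    from torch.distributions import Poisson
    from pyrocov.ops import logistic_logsumexp, sparse_poisson_likelihood

    P, S, T = 5, 4, 3
    alpha, beta, delta = (torch.randn(P, S) for _ in range(3))
    tau = torch.randn(T, P)
    naive = logistic_logsumexp(alpha, beta, delta, tau, backend="naive")
    seq = logistic_logsumexp(alpha, beta, delta, tau, backend="sequential")
    assert torch.allclose(naive, seq, atol=1e-5)

    log_rate = torch.randn(10, 7, dtype=torch.float64)
    value = torch.poisson(log_rate.exp())
    nnz = value.nonzero(as_tuple=True)
    dense = Poisson(log_rate.exp()).log_prob(value).sum()
    sparse = sparse_poisson_likelihood(
        log_rate.logsumexp(-1), log_rate[nnz], value[nnz]
    )
    assert torch.allclose(dense, sparse)

--------------------------------------------------------------------------------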
/pyrocov/plotting.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import torch
5 |
6 |
7 | @torch.no_grad()
8 | def force_apart(*X, radius=[0.05, 0.005], iters=10, stepsize=2, xshift=0.01):
9 | X = torch.stack([torch.as_tensor(x) for x in X], dim=-1)
10 | assert len(X.shape) == 2
11 | radius = torch.as_tensor(radius)
12 | scale = X.max(0).values - X.min(0).values
13 | X /= scale
14 | for _ in range(iters):
15 | XX = X - X[:, None]
16 | r = (XX / radius).square().sum(-1, True)
17 | kernel = r.neg().exp()
18 | F = (XX * radius.square().sum() / radius**2 * kernel).sum(0)
19 | F_norm = F.square().sum(-1, True).sqrt().clamp(min=1e-20)
20 | F *= F_norm.neg().expm1().neg() / F_norm
21 | F *= stepsize
22 | if xshift is not None:
23 | F[:, 0].clamp_(min=0)
24 | X += F / iters
25 | if xshift is not None:
26 | X[:, 0] += xshift
27 | X *= scale
28 | return X.unbind(-1)
29 |
--------------------------------------------------------------------------------
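[Editor's note — usage sketch for /pyrocov/plotting.py, not a repository file]
force_apart nudges nearly-coincident points away from each other (e.g. to
de-overlap scatter labels), returning shifted copies in the original order:

    import torch
    from pyrocov.plotting import force_apart

    x = torch.tensor([0.000, 0.001, 0.002, 1.0])
    y = torch.tensor([0.0, 0.0, 0.0, 1.0])
    x2, y2 = force_apart(x, y)  # the three clustered points are pushed apart

--------------------------------------------------------------------------------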
/pyrocov/sarscov2.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import os
5 | import re
6 | from collections import OrderedDict, defaultdict
7 | from typing import Dict, List, Tuple
8 |
9 | from .aa import DNA_TO_AA
10 | from .align import NEXTCLADE_DATA
11 |
12 | REFERENCE_SEQ = None # loaded lazily
13 |
14 | # Adapted from https://github.com/nextstrain/ncov/blob/50ceffa/defaults/annotation.gff
15 | # Note these are 1-based positions
16 | annotation_tsv = """\
17 | seqname source feature start end score strand frame attribute
18 | . . gene 26245 26472 . + . gene_name "E"
19 | . . gene 26523 27191 . + . gene_name "M"
20 | . . gene 28274 29533 . + . gene_name "N"
21 | . . gene 29558 29674 . + . gene_name "ORF10"
22 | . . gene 28734 28955 . + . gene_name "ORF14"
23 | . . gene 266 13468 . + . gene_name "ORF1a"
24 | . . gene 13468 21555 . + . gene_name "ORF1b"
25 | . . gene 25393 26220 . + . gene_name "ORF3a"
26 | . . gene 27202 27387 . + . gene_name "ORF6"
27 | . . gene 27394 27759 . + . gene_name "ORF7a"
28 | . . gene 27756 27887 . + . gene_name "ORF7b"
29 | . . gene 27894 28259 . + . gene_name "ORF8"
30 | . . gene 28284 28577 . + . gene_name "ORF9b"
31 | . . gene 21563 25384 . + . gene_name "S"
32 | """
33 |
34 |
35 | def _():
36 | genes = []
37 | rows = annotation_tsv.split("\n")
38 | header, rows = rows[0].split("\t"), rows[1:]
39 | for row in rows:
40 | if row:
41 | row = dict(zip(header, row.split("\t")))
42 | gene_name = row["attribute"].split('"')[1]
43 | start = int(row["start"])
44 | end = int(row["end"])
45 | genes.append(((start, end), gene_name))
46 | genes.sort()
47 | return OrderedDict((gene_name, pos) for pos, gene_name in genes)
48 |
49 |
50 | # This maps gene name to the nucleotide position in the genome,
51 | # as measured in the original Wuhan virus.
52 | GENE_TO_POSITION: Dict[str, Tuple[int, int]] = _()
53 |
54 | # This maps gene name to a set of regions in that gene.
55 | # These regions may be used in plotting e.g. mutrans.ipynb.
56 | # Each region has a string label and an extent (start, end)
57 | # measured in amino acid positions relative to the start.
58 | GENE_STRUCTURE: Dict[str, Dict[str, Tuple[int, int]]] = {
59 | # Source: https://www.nature.com/articles/s41401-020-0485-4/figures/2
60 | "S": {
61 | "NTD": (13, 305),
62 | "RBD": (319, 541),
63 | "FC": (682, 685),
64 | "FP": (788, 806),
65 | "HR1": (912, 984),
66 | "HR2": (1163, 1213),
67 | "TM": (1213, 1237),
68 | "CT": (1237, 1273),
69 | },
70 | # Source https://www.nature.com/articles/s41467-021-21953-3
71 | "N": {
72 | "NTD": (1, 49),
73 | "RNA binding": (50, 174),
74 | "SR": (175, 215),
75 | "dimerization": (246, 365),
76 | "CTD": (365, 419),
77 | # "immunogenic": (133, 217),
78 | },
79 | # Source: https://www.ncbi.nlm.nih.gov/protein/YP_009725295.1
80 | "ORF1a": {
81 | "nsp1": (0, 180), # leader protein
82 | "nsp2": (180, 818),
83 | "nsp3": (818, 2763),
84 | "nsp4": (2763, 3263),
85 | "nsp5": (3263, 3569), # 3C-like proteinase
86 | "nsp6": (3569, 3859),
87 | "nsp7": (3859, 3942),
88 | "nsp8": (3942, 4140),
89 | "nsp9": (4140, 4253),
90 | "nsp10": (4253, 4392),
91 | "nsp11": (4392, 4405),
92 | },
93 | # Source: https://www.ncbi.nlm.nih.gov/protein/1796318597
94 | "ORF1ab": {
95 | "nsp1": (0, 180), # leader protein
96 | "nsp2": (180, 818),
97 | "nsp3": (818, 2763),
98 | "nsp4": (2763, 3263),
99 | "nsp5": (3263, 3569), # 3C-like proteinase
100 | "nsp6": (3569, 3859),
101 | "nsp7": (3859, 3942),
102 | "nsp8": (3942, 4140),
103 | "nsp9": (4140, 4253),
104 | "nsp10": (4253, 4392),
105 | "nsp12": (4392, 5324), # RNA-dependent RNA polymerase
106 | "nsp13": (5324, 5925), # helicase
107 | "nsp14": (5925, 6452), # 3'-to-5' exonuclease
108 | "nsp15": (6452, 6798), # endoRNAse
109 | "nsp16": (6798, 7096), # 2'-O-ribose methyltransferase
110 | },
111 | # Source: see infer_ORF1b_structure() below.
112 | "ORF1b": {
113 | "nsp12": (0, 924), # RNA-dependent RNA polymerase
114 | "nsp13": (924, 1525), # helicase
115 | "nsp14": (1525, 2052), # 3'-to-5' exonuclease
116 | "nsp15": (2052, 2398), # endoRNAse
117 | "nsp16": (2398, 2696), # 2'-O-ribose methyltransferase
118 | },
119 | }
120 |
121 |
122 | def load_gene_structure(filename=None, gene_name=None):
123 | """
124 | Loads structure from a GenPept file for use in GENE_STRUCTURE.
125 | This is used only when updating the static GENE_STRUCTURE dict.
126 | """
127 | from Bio import SeqIO
128 |
129 | if filename is None:
130 | assert gene_name is not None
131 | filename = f"data/{gene_name}.gp"
132 | result = {}
133 | with open(filename) as f:
134 | for record in SeqIO.parse(f, format="genbank"):
135 | for feature in record.features:
136 | if feature.type == "mat_peptide":
137 | product = feature.qualifiers["product"][0]
138 | assert isinstance(product, str)
139 | start = int(feature.location.start)
140 | end = int(feature.location.end)
141 | result[product] = start, end
142 | return result
143 |
144 |
145 | def infer_ORF1b_structure():
146 | """
147 | Infers approximate ORF1b structure from ORF1ab.
148 | This is used only when updating the static GENE_STRUCTURE dict.
149 | """
150 | ORF1a_start = GENE_TO_POSITION["ORF1a"][0]
151 | ORF1b_start = GENE_TO_POSITION["ORF1b"][0]
152 | shift = (ORF1b_start - ORF1a_start) // 3
153 | result = {}
154 | for name, (start, end) in GENE_STRUCTURE["ORF1ab"].items():
155 | start -= shift
156 | end -= shift
157 | if end > 0:
158 | start = max(start, 0)
159 | result[name] = start, end
160 | return result
161 |
162 |
163 | def aa_mutation_to_position(m: str) -> int:
164 | """
165 | E.g. map 'S:N501Y' to 21563 + (501 - 1) * 3 = 23063.
166 | """
167 | gene_name, subs = m.split(":")
168 | start, end = GENE_TO_POSITION[gene_name]
169 | match = re.search(r"\d+", subs)
170 | assert match is not None
171 | aa_offset = int(match.group(0)) - 1
172 | return start + aa_offset * 3
173 |
174 |
175 | def nuc_mutations_to_aa_mutations(ms: List[str]) -> List[str]:
176 | global REFERENCE_SEQ
177 | if REFERENCE_SEQ is None:
178 | REFERENCE_SEQ = load_reference_sequence()
179 |
180 | ms_by_aa = defaultdict(list)
181 |
182 | for m in ms:
183 | # Parse a nucleotide mutation such as "A1234G" -> (1234, "G").
184 | # Note this uses 1-based indexing.
185 | if isinstance(m, str):
186 | position_nuc = int(m[1:-1])
187 | new_nuc = m[-1]
188 | else:
189 | # assert isinstance(m, pyrocov.usher.Mutation)
190 | position_nuc = m.position
191 | new_nuc = m.mut
192 |
193 | # Find the first matching gene.
194 | for gene, (start, end) in GENE_TO_POSITION.items():
195 | if start <= position_nuc <= end:
196 | position_aa = (position_nuc - start) // 3
197 | position_codon = (position_nuc - start) % 3
198 | ms_by_aa[gene, position_aa].append((position_codon, new_nuc))
199 |
200 | # Format cumulative amino acid changes.
201 | result = []
202 | for (gene, position_aa), ms in ms_by_aa.items():
203 | start, end = GENE_TO_POSITION[gene]
204 |
205 | # Apply mutation to determine new aa.
206 | pos = start + position_aa * 3
207 | pos -= 1 # convert from 1-based to 0-based
208 | old_codon = REFERENCE_SEQ[pos : pos + 3]
209 | new_codon = list(old_codon)
210 | for position_codon, new_nuc in ms:
211 | new_codon[position_codon] = new_nuc
212 | new_codon = "".join(new_codon)
213 |
214 | # Format.
215 | old_aa = DNA_TO_AA[old_codon]
216 | new_aa = DNA_TO_AA[new_codon]
217 | if new_aa == old_aa: # ignore synonymous substitutions
218 | continue
219 | if old_aa is None:
220 | old_aa = "STOP"
221 | if new_aa is None:
222 | new_aa = "STOP"
223 | result.append(f"{gene}:{old_aa}{position_aa + 1}{new_aa}") # 1-based
224 | return result
225 |
226 |
227 | def load_reference_sequence():
228 | with open(os.path.join(NEXTCLADE_DATA, "reference.fasta")) as f:
229 | ref = "".join(line.strip() for line in f if not line.startswith(">"))
230 | assert len(ref) == 29903, len(ref)
231 | return ref
232 |
--------------------------------------------------------------------------------
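[Editor's note — usage sketch for /pyrocov/sarscov2.py, not a repository file]
GENE_TO_POSITION stores 1-based nucleotide extents parsed from the annotation
above, and aa_mutation_to_position reproduces its docstring's example:

    from pyrocov.sarscov2 import GENE_TO_POSITION, aa_mutation_to_position

    assert GENE_TO_POSITION["S"] == (21563, 25384)
    assert aa_mutation_to_position("S:N501Y") == 21563 + (501 - 1) * 3  # 23063

--------------------------------------------------------------------------------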
/pyrocov/sketch.cpp:
--------------------------------------------------------------------------------
1 | // Copyright Contributors to the Pyro-Cov project.
2 | // SPDX-License-Identifier: Apache-2.0
3 |
4 | #include <cstdint>
5 | #include <string>
6 | #include <unordered_map>
7 |
8 | #include <torch/extension.h>
9 | #include <pybind11/stl.h>
10 |
11 | inline uint64_t murmur64(uint64_t h) {
12 | h ^= h >> 33;
13 | h *= 0xff51afd7ed558ccd;
14 | h ^= h >> 33;
15 | h *= 0xc4ceb9fe1a85ec53;
16 | h ^= h >> 33;
17 | return h;
18 | }
19 |
20 | at::Tensor get_32mers(const std::string& seq) {
21 | static std::vector<int64_t> to_bits(256, false);
22 | to_bits['A'] = 0;
23 | to_bits['C'] = 1;
24 | to_bits['G'] = 2;
25 | to_bits['T'] = 3;
26 |
27 | int size = seq.size() - 32 + 1;
28 | if (size <= 0) {
29 | return at::empty(0, at::kLong);
30 | }
31 | at::Tensor out = at::empty(size, at::kLong);
32 | int64_t * const out_data = static_cast<int64_t*>(out.data_ptr());
33 |
34 | int64_t kmer = 0;
35 | for (int pos = 0; pos < 31; ++pos) {
36 | kmer <<= 2;
37 | kmer ^= to_bits[seq[pos]];
38 | }
39 | for (int pos = 0; pos < size; ++pos) {
40 | kmer <<= 2;
41 | kmer ^= to_bits[seq[pos + 31]];
42 | out_data[pos] = kmer;
43 | }
44 | return out;
45 | }
46 |
47 | struct KmerCounter {
48 | KmerCounter() : to_bits(256, false), counts() {
49 | to_bits['A'] = 0;
50 | to_bits['C'] = 1;
51 | to_bits['G'] = 2;
52 | to_bits['T'] = 3;
53 | }
54 |
55 | void update(const std::string& seq) {
56 | int size = seq.size() - 32 + 1;
57 | if (size <= 0) {
58 | return;
59 | }
60 |
61 | int64_t kmer = 0;
62 | for (int pos = 0; pos < 31; ++pos) {
63 | kmer <<= 2;
64 | kmer ^= to_bits[seq[pos]];
65 | }
66 | for (int pos = 0; pos < size; ++pos) {
67 | kmer <<= 2;
68 | kmer ^= to_bits[seq[pos + 31]];
69 | counts[kmer] += 1;
70 | }
71 | }
72 |
73 | void truncate_below(int64_t threshold) {
74 | for (auto i = counts.begin(); i != counts.end();) {
75 | if (i->second < threshold) {
76 | counts.erase(i++);
77 | } else {
78 | ++i;
79 | }
80 | }
81 | }
82 |
83 | std::unordered_map<int64_t, int64_t> to_dict() const { return counts; }
84 |
85 | std::vector<int64_t> to_bits;
86 | std::unordered_map<int64_t, int64_t> counts;
87 | };
88 |
89 | void string_to_soft_hash(int min_k, int max_k, const std::string& seq, at::Tensor out) {
90 | static std::vector<uint64_t> to_bits(256, false);
91 | to_bits['A'] = 0;
92 | to_bits['C'] = 1;
93 | to_bits['G'] = 2;
94 | to_bits['T'] = 3;
95 |
96 | static std::vector<uint64_t> salts(33);
97 | assert(max_k < salts.size());
98 | for (int k = min_k; k <= max_k; ++k) {
99 | salts[k] = murmur64(1 + k);
100 | }
101 |
102 | float * const data_begin = static_cast<float*>(out.data_ptr());
103 | float * const data_end = data_begin + out.size(-1);
104 | for (int pos = 0, end = seq.size(); pos != end; ++pos) {
105 | if (max_k > end - pos) {
106 | max_k = end - pos;
107 | }
108 | uint64_t hash = 0;
109 | for (int k = 1; k <= max_k; ++k) {
110 | int i = k - 1;
111 | hash ^= to_bits[seq[pos + i]] << (i + i);
112 | if (k < min_k) continue;
113 | uint64_t hash_k = murmur64(salts[k] ^ hash);
114 | for (float *p = data_begin; p != data_end; ++p) {
115 | *p += static_cast<int64_t>(hash_k & 1UL) * 2L - 1L;
116 | hash_k >>= 1UL;
117 | }
118 | }
119 | }
120 | }
121 |
122 | void string_to_clock_hash_v0(int k, const std::string& seq, at::Tensor clocks, at::Tensor count) {
123 | static std::vector<uint64_t> to_bits(256, false);
124 | to_bits['A'] = 0;
125 | to_bits['C'] = 1;
126 | to_bits['G'] = 2;
127 | to_bits['T'] = 3;
128 |
129 | int8_t * const clocks_data = static_cast<int8_t*>(clocks.data_ptr());
130 | const int num_kmers = seq.size() - k + 1;
131 | count.add_(num_kmers);
132 | for (int pos = 0; pos < num_kmers; ++pos) {
133 | uint64_t hash = murmur64(1 + k);
134 | for (int i = 0; i < k; ++i) {
135 | hash ^= to_bits[seq[pos + i]] << (i + i);
136 | }
137 | hash = murmur64(hash);
138 | for (int i = 0; i < 64; ++i) {
139 | clocks_data[i] += (hash >> i) & 1;
140 | }
141 | }
142 | }
143 |
144 | void string_to_clock_hash(int k, const std::string& seq, at::Tensor clocks, at::Tensor count) {
145 | static std::vector<uint64_t> to_bits(256, false);
146 | to_bits['A'] = 0;
147 | to_bits['C'] = 1;
148 | to_bits['G'] = 2;
149 | to_bits['T'] = 3;
150 |
151 | const int num_words = clocks.size(-1) / 64;
152 | uint64_t * const clocks_data = static_cast<uint64_t*>(clocks.data_ptr());
153 | const int num_kmers = seq.size() - k + 1;
154 | count.add_(num_kmers);
155 | for (int pos = 0; pos < num_kmers; ++pos) {
156 | uint64_t hash = 0;
157 | for (int i = 0; i != k; ++i) {
158 | hash ^= to_bits[seq[pos + i]] << (i + i);
159 | }
160 | for (int w = 0; w != num_words; ++w) {
161 | const uint64_t hash_w = murmur64(murmur64(1 + w) ^ hash);
162 | for (int b = 0; b != 8; ++b) {
163 | const int wb = w * 8 + b;
164 | clocks_data[wb] = (clocks_data[wb] + ((hash_w >> b) & 0x0101010101010101UL))
165 | & 0x7F7F7F7F7F7F7F7FUL;
166 | }
167 | }
168 | }
169 | }
170 |
171 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
172 | m.def("get_32mers", &get_32mers, "Extract list of 32-mers from a string");
173 | m.def("string_to_soft_hash", &string_to_soft_hash, "Convert a string to a soft hash");
174 | m.def("string_to_clock_hash", &string_to_clock_hash, "Convert a string to a clock hash");
175 | py::class_<KmerCounter>(m, "KmerCounter")
176 | .def(py::init<>())
177 | .def("update", &KmerCounter::update)
178 | .def("truncate_below", &KmerCounter::truncate_below)
179 | .def("to_dict", &KmerCounter::to_dict);
180 | }
181 |
--------------------------------------------------------------------------------
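[Editor's note — usage sketch for /pyrocov/sketch.cpp, not a repository file]
One way to build and call this extension is PyTorch's JIT compiler; the module
name and source path below are assumptions, and a C++ toolchain must be
available:

    from torch.utils.cpp_extension import load

    sketch = load(name="sketch", sources=["pyrocov/sketch.cpp"])
    kmers = sketch.get_32mers("ACGT" * 16)    # 64 bases -> 33 packed 32-mers
    assert len(kmers) == 64 - 32 + 1

    counter = sketch.KmerCounter()
    counter.update("ACGT" * 16)
    counter.truncate_below(2)                 # drop rare 32-mers
    counts = counter.to_dict()                # {packed kmer: count}

--------------------------------------------------------------------------------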
/pyrocov/softmax_tree.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | from collections import defaultdict
5 |
6 | import pyro.distributions as dist
7 | import torch
8 |
9 | from pyrocov.phylo import Phylogeny
10 |
11 |
12 | class SoftmaxTree(dist.Distribution):
13 | """
14 | Samples a :class:`~pyrocov.phylo.Phylogeny` given parameters of a tree
15 | embedding.
16 |
17 | :param torch.Tensor bit_times: Tensor of times of each bit in the
18 | embedding.
19 | :param torch.Tensor logits: ``(num_leaves, num_bits)``-shaped tensor
20 | parametrizing the independent Bernoulli distributions over bits
21 | in each leaf's embedding.
22 | """
23 |
24 | has_rsample = True # only wrt times, not parents
25 |
26 | def __init__(self, leaf_times, bit_times, logits):
27 | assert leaf_times.dim() == 1
28 | assert bit_times.dim() == 1
29 | assert logits.dim() == 2
30 | assert logits.shape == leaf_times.shape + bit_times.shape
31 | self.leaf_times = leaf_times
32 | self.bit_times = bit_times
33 | self._bernoulli = dist.Bernoulli(logits=logits)
34 | super().__init__()
35 |
36 | @property
37 | def probs(self):
38 | return self._bernoulli.probs
39 |
40 | @property
41 | def logits(self):
42 | return self._bernoulli.logits
43 |
44 | @property
45 | def num_leaves(self):
46 | return self.logits.size(0)
47 |
48 | @property
49 | def num_bits(self):
50 | return self.logits.size(1)
51 |
52 | def entropy(self):
53 | return self._bernoulli.entropy().sum([-1, -2])
54 |
55 | def sample(self, sample_shape=torch.Size()):
56 | if sample_shape:
57 | raise NotImplementedError
58 | raise NotImplementedError("TODO")
59 |
60 | def rsample(self, sample_shape=torch.Size()):
61 | if sample_shape:
62 | raise NotImplementedError
63 | bits = self._bernoulli.sample()
64 | num_leaves, num_bits = bits.shape
65 | phylogeny = _decode(self.leaf_times, self.bit_times, bits, self.probs)
66 | return phylogeny
67 |
68 | def log_prob(self, phylogeny):
69 | """
70 | :param ~pyrocov.phylo.Phylogeny phylogeny:
71 | """
72 | return self.entropy() # Is this right?
73 |
74 |
75 | # TODO Implement a C++ version.
76 | # This costs O(num_bits * num_leaves) sequential time.
77 | def _decode(leaf_times, bit_times, bits, probs):
78 | # Sort bits by time.
79 | bit_times, index = bit_times.sort()
80 | bits = bits[..., index]
81 | probs = probs[..., index]
82 |
83 | # Construct internal nodes.
84 | num_leaves, num_bits = bits.shape
85 | assert num_leaves >= 2
86 | times = torch.cat([leaf_times, leaf_times.new_empty(num_leaves - 1)])
87 | parents = torch.empty(2 * num_leaves - 1, dtype=torch.long)
88 | leaves = torch.arange(num_leaves)
89 |
90 | next_id = num_leaves
91 |
92 | def get_id():
93 | nonlocal next_id
94 | next_id += 1
95 | return next_id - 1
96 |
97 | root = get_id()
98 | parents[root] = -1
99 | partitions = [{frozenset(range(num_leaves)): root}]
100 | for t, b in zip(*bit_times.sort()):
101 | partitions.append({})
102 | for partition, p in partitions[-2].items():
103 | children = defaultdict(set)
104 | for n in partition:
105 | bit = bits[n, b].item()
106 | # TODO Clamp bit if t is later than node n.
107 | children[bit].add(n)
108 | if len(children) == 1:
109 | partitions[-1][partition] = p
110 | continue
111 | assert len(children) == 2
112 | for child in children.values():
113 | if len(child) == 1:
114 | # Terminate at a leaf.
115 | c = child.pop()
116 | else:
117 | # Create a new internal node.
118 | c = get_id()
119 | partitions[-1][frozenset(child)] = c
120 | parents[c] = p
121 | times[p] = t
122 | # Create binarized fans for remaining leaves.
123 | for partition, p in partitions[-1].items():
124 | t = times[torch.tensor(list(partition))].min()
125 | times[p] = t
126 | partition = set(partition)
127 | while len(partition) > 2:
128 | c = get_id()
129 | times[c] = t
130 | parents[c] = p
131 | parents[partition.pop()] = p
132 | p = c
133 | parents[partition.pop()] = p
134 | parents[partition.pop()] = p
135 | assert not partition
136 |
137 | return Phylogeny.from_unsorted(times, parents, leaves)
138 |
--------------------------------------------------------------------------------
/pyrocov/special.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | # Adapted from @viswackftw
5 | # https://github.com/pytorch/pytorch/issues/52973#issuecomment-787587188
6 |
7 | import math
8 |
9 | import torch
10 |
11 |
12 | def ndtr(value: torch.Tensor):
13 | """
14 | Based on the SciPy implementation of ndtr
15 | """
16 | sqrt_half = torch.sqrt(torch.tensor(0.5, dtype=value.dtype))
17 | x = value * sqrt_half
18 | z = abs(x)
19 | y = 0.5 * torch.erfc(z)
20 | output = torch.where(
21 | z < sqrt_half, 0.5 + 0.5 * torch.erf(x), torch.where(x > 0, 1 - y, y)
22 | )
23 | return output
24 |
25 |
26 | def log_ndtr(value: torch.Tensor):
27 | """
28 | Function to compute the log of the normal CDF at value.
29 | This is based on the TFP implementation.
30 | """
31 | dtype = value.dtype
32 | if dtype == torch.float64:
33 | lower, upper = -20, 8
34 | elif dtype == torch.float32:
35 | lower, upper = -10, 5
36 | else:
37 | raise TypeError("value needs to be either float32 or float64")
38 |
39 | # When x < lower, then we perform a fixed series expansion (asymptotic)
40 | # = log(cdf(x)) = log(1 - cdf(-x)) = log(1 / 2 * erfc(-x / sqrt(2)))
41 | # = log(-1 / sqrt(2 * pi) * exp(-x ** 2 / 2) / x * (1 + sum))
42 | # When x >= lower and x <= upper, then we simply perform log(cdf(x))
43 | # When x > upper, then we use the approximation log(cdf(x)) = log(1 - cdf(-x)) \approx -cdf(-x)
44 | return torch.where(
45 | value > upper,
46 | torch.log1p(-ndtr(-value)),
47 | torch.where(value >= lower, torch.log(ndtr(value)), log_ndtr_series(value)),
48 | )
49 |
50 |
51 | def log_ndtr_series(value: torch.Tensor, num_terms=3):
52 | """
53 | Function to compute the asymptotic series expansion of the log of normal CDF
54 | at value.
55 | This is based on the TFP implementation.
56 | """
57 | # sum = sum_{n=1}^{num_terms} (-1)^{n} (2n - 1)!! / x^{2n}))
58 | value_sq = value**2
59 | t1 = -0.5 * (math.log(2 * math.pi) + value_sq) - torch.log(-value)
60 | t2 = torch.zeros_like(value)
61 | value_even_power = value_sq.clone()
62 | double_fac = 1
63 | multiplier = -1
64 | for n in range(1, num_terms + 1):
65 | t2.add_(multiplier * double_fac / value_even_power)
66 | value_even_power.mul_(value_sq)
67 | double_fac *= 2 * n - 1
68 | multiplier *= -1
69 | return t1 + torch.log1p(t2)
70 |
--------------------------------------------------------------------------------
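[Editor's note — usage sketch for /pyrocov/special.py, not a repository file]
The point of log_ndtr is numerical stability in the lower tail, where the
naive composition torch.log(ndtr(x)) underflows to -inf:

    import torch
    from pyrocov.special import log_ndtr, ndtr

    x = torch.tensor([-40.0, -1.0, 0.0, 1.0], dtype=torch.float64)
    assert torch.isfinite(log_ndtr(x)).all()     # series expansion kicks in at -40
    assert torch.isinf(torch.log(ndtr(x)))[0]    # the naive version underflows

--------------------------------------------------------------------------------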
/pyrocov/stats.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import numpy as np
5 | import torch
6 | from scipy.special import log_ndtr
7 |
8 |
9 | def hpd_interval(p: float, samples: torch.Tensor):
10 | assert 0.5 < p < 1
11 | assert samples.shape
12 | pad = int(round((1 - p) * len(samples)))
13 | assert pad > 0, "too few samples"
14 | width = samples[-pad:] - samples[:pad]
15 | lb = width.max(0).indices
16 | ub = len(samples) - lb - 1
17 | i = torch.stack([lb, ub])
18 | return samples.gather(0, i)
19 |
20 |
21 | def confidence_interval(p: float, samples: torch.Tensor):
22 | assert 0.5 < p < 1
23 | assert samples.shape
24 | pad = (1 - p) / 2
25 | lk = int(round(pad * (len(samples) - 1)))
26 | uk = int(round((1 - pad) * (len(samples) - 1)))
27 | assert lk > 0, "too few samples"
28 | lb = samples.kthvalue(lk, 0).values
29 | ub = samples.kthvalue(uk, 0).values
30 | return torch.stack([lb, ub])
31 |
32 |
33 | def normal_log10bf(mean, std=1.0):
34 | """
35 | Returns ``log10(P[x>0] / P[x<0])`` for ``x ~ N(mean, std)``.
36 | """
37 | z = mean / std
38 | return (log_ndtr(z) - log_ndtr(-z)) / np.log(10)
39 |
--------------------------------------------------------------------------------
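[Editor's note — usage sketch for /pyrocov/stats.py, not a repository file]
hpd_interval expects samples pre-sorted along dim 0 (it compares order
statistics), while confidence_interval uses kthvalue and does not. Toy usage:

    import torch
    from pyrocov.stats import confidence_interval, normal_log10bf

    samples = torch.randn(1000).sort(0).values
    lb, ub = confidence_interval(0.95, samples)  # equal-tailed 95% interval

    # log10 Bayes factor P[x>0] / P[x<0] for x ~ N(2, 1):
    assert normal_log10bf(2.0) > 1.0

--------------------------------------------------------------------------------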
/pyrocov/substitution.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import pyro.distributions as dist
5 | import torch
6 | from pyro.nn import PyroModule, PyroSample, pyro_method
7 |
8 |
9 | class SubstitutionModel(PyroModule):
10 | """
11 | Probabilistic substitution model among a finite number of states
12 | (typically 4 for nucleotides or 20 for amino acids).
13 |
14 | This returns a continuous time transition matrix.
15 | """
16 |
17 | @pyro_method
18 | def matrix_exp(self, dt):
19 | m = self().to(dt.dtype)
20 | return (m * dt[:, None, None]).matrix_exp()
21 |
22 | @pyro_method
23 | def log_matrix_exp(self, dt):
24 | m = self.matrix_exp(dt)
25 | m.data.clamp_(torch.finfo(m.dtype).eps)
26 | return m.log()
27 |
28 |
29 | class JukesCantor69(SubstitutionModel):
30 | """
31 | A simple uniform substitution model with a single latent rate parameter.
32 |
33 | This provides a weak Exponential(1) prior over the rate parameter.
34 |
35 | [1] T.H. Jukes, C.R. Cantor (1969) "Evolution of protein molecules"
36 | [2] https://en.wikipedia.org/wiki/Models_of_DNA_evolution#JC69_model_(Jukes_and_Cantor_1969)
37 | """
38 |
39 | def __init__(self, *, dim=4):
40 | assert isinstance(dim, int) and dim > 0
41 | super().__init__()
42 | self.dim = dim
43 | self.rate = PyroSample(dist.Exponential(1.0))
44 |
45 | def forward(self):
46 | D = self.dim
47 | return self.rate * (1.0 / D - torch.eye(D))
48 |
49 | @pyro_method
50 | def matrix_exp(self, dt):
51 | D = self.dim
52 | rate = torch.as_tensor(self.rate, dtype=dt.dtype)
53 | p = dt.mul(-rate).exp()[:, None, None]
54 | q = (1 - p) / D
55 | return torch.where(torch.eye(D, dtype=torch.bool), p + q, q)
56 |
57 | @pyro_method
58 | def log_matrix_exp(self, dt):
59 | D = self.dim
60 | rate = torch.as_tensor(self.rate, dtype=dt.dtype)
61 | p = dt.mul(-rate).exp()[:, None, None]
62 | q = (1 - p) / D
63 | q.data.clamp_(min=torch.finfo(q.dtype).eps)
64 | on_diag = (p + q).log()
65 | off_diag = q.log()
66 | return torch.where(torch.eye(D, dtype=torch.bool), on_diag, off_diag)
67 |
68 |
69 | class GeneralizedTimeReversible(SubstitutionModel):
70 | """
71 | Generalized time-reversible substitution model among ``dim``-many states.
72 |
73 | This provides a weak Dirichlet(2) prior over the steady state distribution
74 | and a weak Exponential(1) prior over mutation rates.
75 | """
76 |
77 | def __init__(self, *, dim=4):
78 | assert isinstance(dim, int) and dim > 0
79 | super().__init__()
80 | self.dim = dim
81 | self.stationary = PyroSample(dist.Dirichlet(torch.full((dim,), 2.0)))
82 | self.rates = PyroSample(
83 | dist.Exponential(torch.ones(dim * (dim - 1) // 2)).to_event(1)
84 | )
85 | i = torch.arange(dim)
86 | self._index = (i > i[:, None]).nonzero(as_tuple=False).T
87 |
88 | def forward(self):
89 | p = self.stationary
90 | i, j = self._index
91 | m = torch.zeros(self.dim, self.dim)
92 | m[i, j] = self.rates
93 | m = m + m.T * (p / p[:, None])
94 | m = m - m.sum(dim=-1).diag_embed()
95 | return m
96 |
--------------------------------------------------------------------------------
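[Editor's note — usage sketch for /pyrocov/substitution.py, not a repository file]
Each row of the transition matrix exp(m*dt) is a probability distribution over
destination states, so rows must sum to 1. A sanity check on the Jukes-Cantor
model (the rate is sampled from its Exponential(1) prior on access):

    import torch
    from pyrocov.substitution import JukesCantor69

    model = JukesCantor69(dim=4)
    dt = torch.tensor([0.1, 1.0, 10.0])
    probs = model.matrix_exp(dt)  # shape (3, 4, 4)
    assert torch.allclose(probs.sum(-1), torch.ones(3, 4), atol=1e-6)

--------------------------------------------------------------------------------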
/pyrocov/util.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import functools
5 | import gzip
6 | import itertools
7 | import operator
8 | import os
9 | import weakref
10 | from typing import Dict
11 |
12 | import pyro
13 | import torch
14 | import tqdm
15 | from torch.distributions import constraints, transform_to
16 |
17 |
18 | def pearson_correlation(x: torch.Tensor, y: torch.Tensor):
19 | x = (x - x.mean()) / x.std()
20 | y = (y - y.mean()) / y.std()
21 | return (x * y).mean()
22 |
23 |
24 | def pyro_param(name, shape, constraint=constraints.real):
25 | transform = transform_to(constraint)
26 | terms = []
27 | for subshape in itertools.product(*({1, int(size)} for size in shape)):
28 | subname = "_".join([name] + list(map(str, subshape)))
29 | subinit = functools.partial(torch.zeros, subshape)
30 | terms.append(pyro.param(subname, subinit))
31 | unconstrained = functools.reduce(operator.add, terms)
32 | return transform(unconstrained)
33 |
34 |
35 | def quotient_central_moments(
36 | fine_values: torch.Tensor, fine_to_coarse: torch.Tensor
37 | ) -> torch.Tensor:
38 | """
39 | Returns (zeroth, first, second) central moments of each coarse cluster of
40 | fine values, i.e. (count, mean, stddev).
41 |
42 | :returns: A single stacked tensor of shape ``(3, num_coarse)``.
43 | """
44 | C = 1 + int(fine_to_coarse.max())
45 | moments = torch.zeros(3, C)
46 | moments[0].scatter_add_(0, fine_to_coarse, torch.ones_like(fine_values))
47 | moments[1].scatter_add_(0, fine_to_coarse, fine_values)
48 | moments[1] /= moments[0]
49 | fine_diff2 = (fine_values - moments[1][fine_to_coarse]).square()
50 | moments[2].scatter_add_(0, fine_to_coarse, fine_diff2)
51 | moments[2] /= moments[0]
52 | moments[2].sqrt_()
53 | return moments
54 |
55 |
56 | def weak_memoize_by_id(fn):
57 | cache = {}
58 | missing = object() # An arbitrary value that cannot be returned by fn.
59 |
60 | @functools.wraps(fn)
61 | def memoized_fn(*args):
62 | key = tuple(map(id, args))
63 | result = cache.get(key, missing)
64 | if result is missing:
65 | result = cache[key] = fn(*args)
66 | for arg in args:
67 | # Register callbacks only for types that support weakref.
68 | if type(arg).__weakrefoffset__:
69 | weakref.finalize(arg, cache.pop, key, None)
70 | return result
71 |
72 | return memoized_fn
73 |
74 |
75 | _TENSORS: Dict[tuple, torch.Tensor] = {}
76 |
77 |
78 | def deduplicate_tensor(x):
79 | key = x.dtype, x.stride(), x.data_ptr()
80 | return _TENSORS.setdefault(key, x)
81 |
82 |
83 | def torch_map(x, **kwargs):
84 | """
85 | Calls ``leaf.to(**kwargs)`` on all tensor and module leaves of a nested
86 | data structure.
87 | """
88 | return _torch_map(x, **kwargs)[0]
89 |
90 |
91 | @functools.singledispatch
92 | def _torch_map(x, **kwargs):
93 | return x, False
94 |
95 |
96 | @_torch_map.register(torch.Tensor)
97 | def _torch_map_tensor(x, **kwargs):
98 | x_ = x.to(**kwargs)
99 | changed = x_ is not x
100 | return x_, changed
101 |
102 |
103 | @_torch_map.register(torch.nn.Module)
104 | def _torch_map_module(x, **kwargs):
105 | changed = True  # Conservatively assume .to() changed the module.
106 | return x.to(**kwargs), changed
107 |
108 |
109 | @_torch_map.register(dict)
110 | def _torch_map_dict(x, **kwargs):
111 | result = type(x)()
112 | changed = False
113 | for k, v in x.items():
114 | v, v_changed = _torch_map(v, **kwargs)
115 | result[k] = v
116 | changed = changed or v_changed
117 | return (result, True) if changed else (x, False)
118 |
119 |
120 | @_torch_map.register(list)
121 | @_torch_map.register(tuple)
122 | def _torch_map_iterable(x, **kwargs):
123 | result = []
124 | changed = False
125 | for v in x:
126 | v, v_changed = _torch_map(v, **kwargs)
127 | result.append(v)
128 | changed = changed or v_changed
129 | result = type(x)(result)
130 | return (result, True) if changed else (x, False)
131 |
132 |
133 | def pretty_print(x, *, name="", max_items=10):
134 | if isinstance(x, (int, float, str, bool)):
135 | print(f"{name} = {repr(x)}")
136 | elif isinstance(x, torch.Tensor):
137 | print(f"{name}: {type(x).__name__} of shape {tuple(x.shape)}")
138 | elif isinstance(x, (tuple, list)):
139 | print(f"{name}: {type(x).__name__} of length {len(x)}")
140 | elif isinstance(x, dict):
141 | print(f"{name}: {type(x).__name__} of length {len(x)}")
142 | if len(x) <= max_items:
143 | for k, v in x.items():
144 | pretty_print(v, name=f"{name}[{repr(k)}]", max_items=max_items)
145 | else:
146 | print(f"{name}: {type(x).__name__}")
147 |
148 |
149 | def generate_colors(num_points=100, lb=0.5, ub=2.5):
150 | """
151 | Constructs a quasirandom collection of colors for plotting.
152 | """
153 | # http://extremelearning.com.au/unreasonable-effectiveness-of-quasirandom-sequences/
154 | phi3 = 1.2207440846
155 | alpha = torch.tensor([1 / phi3**3, 1 / phi3**2, 1 / phi3])
156 | t = torch.arange(float(2 * num_points))
157 | rgb = alpha.mul(t[:, None]).add(torch.tensor([0.8, 0.2, 0.1])).fmod(1)
158 | total = rgb.sum(-1)
159 | rgb = rgb[(lb <= total) & (total <= ub)]
160 | rgb = rgb[:num_points]
161 | assert len(rgb) == num_points
162 | return [f"#{r:02x}{g:02x}{b:02x}" for r, g, b in rgb.mul(256).long().tolist()]
163 |
164 |
165 | def open_tqdm(*args, **kwargs):
166 | with open(*args, **kwargs) as f:
167 | with tqdm.tqdm(
168 | total=os.stat(f.fileno()).st_size,
169 | unit="B",
170 | unit_scale=True,
171 | unit_divisor=1024,
172 | smoothing=0,
173 | ) as pbar:
174 | for line in f:
175 | pbar.update(len(line))
176 | yield line
177 |
178 |
179 | def gzip_open_tqdm(filename, mode="rb"):
180 | with open(filename, "rb") as f, gzip.open(f, mode) as g:
181 | with tqdm.tqdm(
182 | total=os.stat(f.fileno()).st_size,
183 | unit="B",
184 | unit_scale=True,
185 | unit_divisor=1024,
186 | smoothing=0,
187 | ) as pbar:
188 | for line in g:
189 | pbar.n = f.tell()
190 | pbar.update(0)
191 | yield line
192 |
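Note: a hypothetical usage sketch for quotient_central_moments() above, assuming this module's definitions are in scope. Six fine values in two coarse clusters yield per-cluster (count, mean, stddev).

    import torch

    fine_values = torch.tensor([1.0, 2.0, 3.0, 10.0, 11.0, 12.0])
    fine_to_coarse = torch.tensor([0, 0, 0, 1, 1, 1])  # cluster id of each fine value
    count, mean, std = quotient_central_moments(fine_values, fine_to_coarse)
    # count == [3., 3.], mean == [2., 11.], and std is the population
    # (not Bessel-corrected) standard deviation within each cluster.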
--------------------------------------------------------------------------------
/scripts/fix_columns.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import glob
5 | import logging
6 | import os
7 | import pickle
8 |
9 | import tqdm
10 |
11 | from pyrocov.geo import gisaid_normalize
12 |
13 | logger = logging.getLogger(__name__)
14 | logging.basicConfig(format="%(relativeCreated) 9d %(message)s", level=logging.INFO)
15 |
16 |
17 | def main():
18 | """
19 | Fixes columns["location"] via gisaid_normalize().
20 | """
21 | tempfile = "results/temp.columns.pkl"
22 | for infile in glob.glob("results/*.columns.pkl"):
23 | if "temp" in infile:
24 | continue
25 | logger.info(f"Processing {infile}")
26 | with open(infile, "rb") as f:
27 | columns = pickle.load(f)
28 | columns["location"] = [
29 | gisaid_normalize(x) for x in tqdm.tqdm(columns["location"])
30 | ]
31 | with open(tempfile, "wb") as f:
32 | pickle.dump(columns, f)
33 | os.rename(tempfile, infile) # atomic
34 |
35 |
36 | if __name__ == "__main__":
37 | main()
38 |
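Note: the write-then-rename step above is what makes the update atomic; a minimal sketch of the pattern (helper name hypothetical), valid when the temp file and target live on the same POSIX filesystem:

    import os
    import pickle

    def atomic_pickle_dump(obj, path, tmp_path):
        # Write the full payload to a sibling temp file first ...
        with open(tmp_path, "wb") as f:
            pickle.dump(obj, f)
        # ... then atomically swap it into place; readers never see a partial file.
        os.rename(tmp_path, path)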
--------------------------------------------------------------------------------
/scripts/git_pull.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import os
5 | import sys
6 | from subprocess import check_call
7 |
8 | # This keeps repos organized as ~/github/{user}/{repo}
9 | GITHUB = os.path.expanduser(os.path.join("~", "github"))
10 | if not os.path.exists(GITHUB):
11 | os.makedirs(GITHUB)
12 |
13 | update = True
14 | for arg in sys.argv[1:]:
15 | if arg == "--no-update":
16 | update = False
17 | continue
18 |
19 | try:
20 | user, repo = arg.split("/")
21 | except Exception:
22 | raise ValueError(
23 | f"Expected args of the form username/repo e.g. pyro-ppl/pyro, but got {arg}"
24 | )
25 |
26 | dirname = os.path.join(GITHUB, user)
27 | if not os.path.exists(dirname):
28 | os.makedirs(dirname)
29 | os.chdir(dirname)
30 | dirname = os.path.join(dirname, repo)
31 | if not os.path.exists(dirname):
32 | print(f"Cloning {arg}")
33 | check_call(
34 | ["git", "clone", "--depth", "1", f"https://github.com/{user}/{repo}"]
35 | )
36 | elif update:
37 | print(f"Pulling {arg}")
38 | os.chdir(dirname)
39 | check_call(["git", "pull"])
40 |
--------------------------------------------------------------------------------
/scripts/moran.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import argparse
5 | from collections import defaultdict
6 |
7 | import numpy as np
8 | import pandas as pd
9 | import torch
10 |
11 | from pyrocov.sarscov2 import aa_mutation_to_position
12 |
13 |
14 | # compute the Moran statistic
15 | def moran(values, distances, lengthscale):
16 | assert values.size(-1) == distances.size(-1)
17 | weights = (distances.unsqueeze(-1) - distances.unsqueeze(-2)) / lengthscale
18 | weights = torch.exp(-weights.pow(2.0))
19 | weights *= 1.0 - torch.eye(weights.size(-1))
20 | weights /= weights.sum(-1, keepdim=True)
21 |
22 | output = torch.einsum("...ij,...i,...j->...", weights, values, values)
23 | return output / values.pow(2.0).sum(-1)
24 |
25 |
26 | # compute the Moran statistic and do a permutation test with the given number of permutations
27 | def permutation_test(values, distances, lengthscale, num_perm=999):
28 | values = values - values.mean()
29 | moran_given = moran(values, distances, lengthscale).item()
30 | idx = [torch.randperm(distances.size(-1)) for _ in range(num_perm)]
31 | idx = torch.stack(idx)
32 | moran_perm = moran(values[idx], distances, lengthscale)
33 | p_value = (moran_perm >= moran_given).sum().item() + 1
34 | p_value /= float(num_perm + 1)
35 | return moran_given, p_value
36 |
37 |
38 | def main(args):
39 | # read in inferred mutations
40 | df = pd.read_csv("paper/mutations.tsv", sep="\t", index_col=0)
41 | df = df[["mutation", "Δ log R"]]
42 | mutations = df.values[:, 0]
43 | assert mutations.shape == (2904,)
44 | coefficients = df.values[:, 1] if not args.magnitude else np.abs(df.values[:, 1])
45 | gene_map = defaultdict(list)
46 | distance_map = defaultdict(list)
47 |
48 | results = []
49 |
50 | # collect indices and nucleotide positions corresponding to each mutation
51 | for i, m in enumerate(mutations):
52 | gene = m.split(":")[0]
53 | gene_map[gene].append(i)
54 | distance_map[gene].append(aa_mutation_to_position(m))
55 |
56 | # map over each gene
57 | for gene, idx in gene_map.items():
58 | values = torch.from_numpy(np.array(coefficients[idx], dtype=np.float32))
59 | distances = distance_map[gene]
60 | distances = torch.from_numpy(np.array(distances) - min(distances))
61 | gene_size = distances.max().item()
62 | lengthscale = min(gene_size / 20, 50.0)
63 | _, p_value = permutation_test(values, distances, lengthscale, num_perm=999999)
64 | s = "Gene: {} \t #Mut: {} Size: {} \t p-value: {:.6f} Lengthscale: {:.1f}"
65 | print(s.format(gene, distances.size(0), gene_size, p_value, lengthscale))
66 | results.append([distances.size(0), gene_size, p_value, lengthscale])
67 |
68 | # compute the Moran statistic for the entire genome at multiple lengthscales
69 | for global_lengthscale in [100.0, 500.0]:
70 | distances_ = [aa_mutation_to_position(m) for m in mutations]
71 | distances = torch.from_numpy(
72 | np.array(distances_, dtype=np.float32) - min(distances_)
73 | )
74 | values = torch.tensor(np.array(coefficients, dtype=np.float32)).float()
75 | _, p_value = permutation_test(
76 | values, distances, global_lengthscale, num_perm=999999
77 | )
78 | genome_size = distances.max().item()
79 | s = "Entire Genome (#Mut = {}; Size = {}): \t p-value: {:.6f} Lengthscale: {:.1f}"
80 | print(s.format(distances.size(0), genome_size, p_value, global_lengthscale))
81 | results.append([distances.size(0), genome_size, p_value, global_lengthscale])
82 |
83 | # save results as csv
84 | results = np.stack(results)
85 | columns = ["NumMutations", "GeneSize", "PValue", "Lengthscale"]
86 | index = list(gene_map.keys()) + ["EntireGenome"] * 2
87 | result = pd.DataFrame(data=results, index=index, columns=columns)
88 | result.sort_values(["PValue"]).to_csv("paper/moran.csv")
89 |
90 |
91 | if __name__ == "__main__":
92 | parser = argparse.ArgumentParser(description="Compute Moran statistics")
93 | parser.add_argument("--magnitude", action="store_true")
94 | args = parser.parse_args()
95 | main(args)
96 |
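Note: a toy sketch (values illustrative, definitions above assumed in scope) of how permutation_test() behaves: spatially clustered coefficients yield a strongly positive Moran statistic and a small p-value.

    import torch

    distances = torch.arange(10.0)                  # positions along a hypothetical gene
    values = torch.tensor([1.0] * 5 + [-1.0] * 5)   # clustered effect sizes
    stat, p_value = permutation_test(values, distances, lengthscale=2.0, num_perm=999)
    # stat should be strongly positive and p_value small, since few random
    # permutations look as spatially clustered as the original ordering.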
--------------------------------------------------------------------------------
/scripts/preprocess_credits.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import argparse
5 | import logging
6 | import lzma
7 | import pickle
8 |
9 | logger = logging.getLogger(__name__)
10 | logging.basicConfig(format="%(relativeCreated) 9d %(message)s", level=logging.INFO)
11 |
12 |
13 | def main(args):
14 | logger.info(f"Loading {args.columns_file_in}")
15 | with open(args.columns_file_in, "rb") as f:
16 | columns = pickle.load(f)
17 |
18 | logger.info(f"Saving {args.credits_file_out}")
19 | with lzma.open(args.credits_file_out, "wt") as f:
20 | f.write("\n".join(sorted(columns["index"])))
21 |
22 |
23 | if __name__ == "__main__":
24 | parser = argparse.ArgumentParser(description="Save GISAID accession numbers")
25 | parser.add_argument("--columns-file-in", default="results/columns.3000.pkl")
26 | parser.add_argument("--credits-file-out", default="paper/accession_ids.txt.xz")
27 | args = parser.parse_args()
28 | main(args)
29 |
--------------------------------------------------------------------------------
/scripts/preprocess_gisaid.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import argparse
5 | import datetime
6 | import json
7 | import logging
8 | import os
9 | import pickle
10 | import warnings
11 | from collections import Counter, defaultdict
12 |
13 | from pyrocov import pangolin
14 | from pyrocov.geo import gisaid_normalize
15 | from pyrocov.mutrans import START_DATE
16 | from pyrocov.util import open_tqdm
17 |
18 | logger = logging.getLogger(__name__)
19 | logging.basicConfig(format="%(relativeCreated) 9d %(message)s", level=logging.INFO)
20 |
21 | DATE_FORMATS = {4: "%Y", 7: "%Y-%m", 10: "%Y-%m-%d"}
22 |
23 |
24 | def parse_date(string):
25 | fmt = DATE_FORMATS.get(len(string))
26 | if fmt is None:
27 | # Attempt to fix poorly formatted dates like 2020-09-1.
28 | parts = string.split("-")
29 | parts = parts[:1] + [f"{int(p):>02d}" for p in parts[1:]]
30 | string = "-".join(parts)
31 | fmt = DATE_FORMATS[len(string)]
32 | return datetime.datetime.strptime(string, fmt)
33 |
34 |
35 | FIELDS = ["virus_name", "accession_id", "collection_date", "location", "add_location"]
36 |
37 |
38 | def main(args):
39 | logger.info(f"Filtering {args.gisaid_file_in}")
40 | if not os.path.exists(args.gisaid_file_in):
41 | raise OSError(f"Missing {args.gisaid_file_in}; you may need to request a feed")
42 | os.makedirs("results", exist_ok=True)
43 |
44 | columns = defaultdict(list)
45 | stats = defaultdict(Counter)
46 | covv_fields = ["covv_" + key for key in FIELDS]
47 |
48 | for i, line in enumerate(open_tqdm(args.gisaid_file_in)):
49 | # Optimize for faster reading.
50 | line, _ = line.split(', "sequence": ', 1)
51 | line += "}"
52 |
53 | # Filter out bad data.
54 | datum = json.loads(line)
55 | if len(datum["covv_collection_date"]) < 7:
56 | continue # Drop rows with no month information.
57 | date = parse_date(datum["covv_collection_date"])
58 | if date < args.start_date:
59 | date = args.start_date # Clip rows before start date.
60 | lineage = datum["covv_lineage"]
61 | if lineage in (None, "None", ""):
62 | continue # Drop rows with unknown lineage.
63 | try:
64 | lineage = pangolin.compress(lineage)
65 | lineage = pangolin.decompress(lineage)
66 | assert lineage
67 | except (ValueError, AssertionError) as e:
68 | warnings.warn(str(e))
69 | continue
70 |
71 | # Fix duplicate locations.
72 | datum["covv_location"] = gisaid_normalize(datum["covv_location"])
73 |
74 | # Collate.
75 | columns["lineage"].append(lineage)
76 | for covv_key, key in zip(covv_fields, FIELDS):
77 | columns[key].append(datum[covv_key])
78 | columns["day"].append((date - args.start_date).days)
79 |
80 | # Aggregate statistics.
81 | stats["date"][datum["covv_collection_date"]] += 1
82 | stats["location"][datum["covv_location"]] += 1
83 | stats["lineage"][lineage] += 1
84 |
85 | if i >= args.truncate:
86 | break
87 |
88 | num_dropped = i + 1 - len(columns["day"])
89 | logger.info(f"dropped {num_dropped}/{i+1} = {num_dropped*100/(i+1):0.2g}% rows")
90 |
91 | logger.info(f"saving {args.columns_file_out}")
92 | with open(args.columns_file_out, "wb") as f:
93 | pickle.dump(dict(columns), f)
94 |
95 | logger.info(f"saving {args.stats_file_out}")
96 | with open(args.stats_file_out, "wb") as f:
97 | pickle.dump(dict(stats), f)
98 |
99 |
100 | if __name__ == "__main__":
101 | parser = argparse.ArgumentParser(description="Preprocess GISAID data")
102 | parser.add_argument("--gisaid-file-in", default="results/gisaid.json")
103 | parser.add_argument("--columns-file-out", default="results/gisaid.columns.pkl")
104 | parser.add_argument("--stats-file-out", default="results/gisaid.stats.pkl")
105 | parser.add_argument("--start-date", default=START_DATE)
106 | parser.add_argument("--truncate", default=int(1e10), type=int)
107 | args = parser.parse_args()
108 | args.start_date = parse_date(args.start_date)
109 | main(args)
110 |
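Note: examples of the normalization performed by parse_date() above (outputs shown as comments):

    parse_date("2021")       # -> datetime.datetime(2021, 1, 1, 0, 0)
    parse_date("2021-03")    # -> datetime.datetime(2021, 3, 1, 0, 0)
    parse_date("2020-09-1")  # zero-padded to "2020-09-01", then parsed as %Y-%m-%d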
--------------------------------------------------------------------------------
/scripts/preprocess_nextstrain.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | """
5 | Preprocess Nextstrain open data.
6 |
7 | This script aggregates the metadata.tsv.gz file available from:
8 | https://docs.nextstrain.org/projects/ncov/en/latest/reference/remote_inputs.html
9 | This file is mirrored on S3 and GCP:
10 | https://data.nextstrain.org/files/ncov/open/metadata.tsv.gz
11 | s3://nextstrain-data/files/ncov/open/metadata.tsv.gz
12 | gs://nextstrain-data/files/ncov/open/metadata.tsv.gz
13 | """
14 |
15 | import argparse
16 | import datetime
17 | import logging
18 | import pickle
19 | from collections import Counter, defaultdict
20 |
21 | import torch
22 |
23 | from pyrocov.growth import START_DATE, dense_to_sparse
24 | from pyrocov.util import gzip_open_tqdm
25 |
26 | logger = logging.getLogger(__name__)
27 | logging.basicConfig(format="%(relativeCreated) 9d %(message)s", level=logging.INFO)
28 |
29 |
30 | def parse_date(string):
31 | return datetime.datetime.strptime(string, "%Y-%m-%d")
32 |
33 |
34 | def coarsen_locations(args, counts):
35 | """
36 | Select regions that have at least ``args.min_region_size`` samples.
37 | Remaining regions will be coarsely aggregated up to country level.
38 | """
39 | locations = set()
40 | coarsen_location = {}
41 | for location, count in counts.items():
42 | if " / " in location and counts[location] < args.min_region_size:
43 | old = location
44 | location = location.split(" / ")[0]
45 | coarsen_location[old] = location
46 | locations.add(location)
47 | locations = sorted(locations)
48 | logger.info(f"kept {len(locations)}/{len(counts)} locations")
49 | return locations, coarsen_location
50 |
51 |
52 | def main(args):
53 | columns = defaultdict(list)
54 | stats = defaultdict(Counter)
55 | skipped = Counter()
56 |
57 | # Process rows one at a time.
58 | logger.info(f"Reading {args.metadata_file_in}")
59 | header = None
60 | for line in gzip_open_tqdm(args.metadata_file_in, "rt"):
61 | line = line.strip().split("\t")
62 | if header is None:
63 | header = line
64 | continue
65 | row = dict(zip(header, line))
66 |
67 | # Parse date.
68 | try:
69 | date = parse_date(row["date"])
70 | except ValueError:
71 | skipped["date"] += 1
72 | continue
73 | day = (date - args.start_date).days
74 |
75 | # Parse location.
76 | location = row["country"]
77 | if location in ("", "?"):
78 | skipped["location"] += 1
79 | continue
80 | division = row["division"]
81 | if division not in ("", "?"):
82 | location += " / " + division
83 |
84 | # Parse lineage.
85 | lineage = row["pango_lineage"]
86 | if lineage in ("", "?", "unclassifiable"):
87 | skipped["lineage"] += 1
88 | continue
89 | assert lineage[0] in "ABCDEFGHIJKLMNOPQRSTUVWXYZ", lineage
90 |
91 | # Append row.
92 | columns["day"].append(day)
93 | columns["location"].append(location)
94 | columns["lineage"].append(lineage)
95 |
96 | # Record stats.
97 | stats["day"][day] += 1
98 | stats["location"][location] += 1
99 | stats["lineage"][lineage] += 1
100 | for aa in row["aaSubstitutions"].split(","):
101 | stats["aa"][aa] += 1
102 | stats["lineage_aa"][lineage, aa] += 1
103 | columns = dict(columns)
104 | stats = dict(stats)
105 | logger.info(f"kept {len(columns['location'])} rows")
106 | logger.info(f"skipped {sum(skipped.values())} due to:\n{dict(skipped)}")
107 | for k, v in stats.items():
108 | logger.info(f"found {len(v)} {k}s")
109 |
110 | logger.info(f"saving {args.stats_file_out}")
111 | with open(args.stats_file_out, "wb") as f:
112 | pickle.dump(stats, f)
113 |
114 | logger.info(f"saving {args.columns_file_out}")
115 | with open(args.columns_file_out, "wb") as f:
116 | pickle.dump(columns, f)
117 |
118 | # Create contiguous coordinates.
119 | locations = sorted(stats["location"])
120 | lineages = sorted(stats["lineage"])
121 | aa_counts = Counter()
122 | for (lineage, aa), count in stats["lineage_aa"].most_common():
123 | if count * 2 >= stats["lineage"][lineage]:
124 | aa_counts[aa] += count
125 | logger.info(f"kept {len(aa_counts)}/{len(stats['aa'])} aa substitutions")
126 | aa_mutations = [aa for aa, _ in aa_counts.most_common()]
127 |
128 | # Create a dense feature matrix.
129 | aa_features = torch.zeros(len(lineages), len(aa_mutations), dtype=torch.float)
130 | logger.info(
131 | f"saving {tuple(aa_features.shape)} features to {args.features_file_out}"
132 | )
133 | for s, lineage in enumerate(lineages):
134 | for f, aa in enumerate(aa_mutations):
135 | count = stats["lineage_aa"].get((lineage, aa))
136 | if count is None:
137 | continue
138 | aa_features[s, f] = count / stats["lineage"][lineage]
139 | features = {
140 | "lineages": lineages,
141 | "aa_mutations": aa_mutations,
142 | "aa_features": aa_features,
143 | }
144 | with open(args.features_file_out, "wb") as f:
145 | torch.save(features, f)
146 |
147 | # Create a dense dataset.
148 | locations, coarsen_location = coarsen_locations(args, stats["location"])
149 | location_id = {location: i for i, location in enumerate(locations)}
150 | lineage_id = {lineage: i for i, lineage in enumerate(lineages)}
151 | T = max(stats["day"]) // args.time_step_days + 1
152 | P = len(locations)
153 | S = len(lineages)
154 | counts = torch.zeros(T, P, S)
155 | for day, location, lineage in zip(
156 | columns["day"], columns["location"], columns["lineage"]
157 | ):
158 | location = coarsen_location.get(location, location)
159 | t = day // args.time_step_days
160 | p = location_id[location]
161 | s = lineage_id[lineage]
162 | counts[t, p, s] += 1
163 | logger.info(f"counts data is {counts.ne(0).float().mean().item()*100:0.3g}% dense")
164 | sparse_counts = dense_to_sparse(counts)
165 | place_lineage_index = counts.ne(0).any(0).reshape(-1).nonzero(as_tuple=True)[0]
166 | logger.info(f"saving {tuple(counts.shape)} counts to {args.dataset_file_out}")
167 | dataset = {
168 | "start_date": args.start_date,
169 | "time_step_days": args.time_step_days,
170 | "locations": locations,
171 | "lineages": lineages,
172 | "mutations": aa_mutations,
173 | "features": aa_features,
174 | "weekly_counts": counts,
175 | "sparse_counts": sparse_counts,
176 | "place_lineage_index": place_lineage_index,
177 | }
178 | with open(args.dataset_file_out, "wb") as f:
179 | torch.save(dataset, f)
180 |
181 |
182 | if __name__ == "__main__":
183 | parser = argparse.ArgumentParser(description=__doc__)
184 | parser.add_argument(
185 | "--metadata-file-in", default="results/nextstrain/metadata.tsv.gz"
186 | )
187 | parser.add_argument("--columns-file-out", default="results/nextstrain.columns.pkl")
188 | parser.add_argument("--stats-file-out", default="results/nextstrain.stats.pkl")
189 | parser.add_argument("--features-file-out", default="results/nextstrain.features.pt")
190 | parser.add_argument("--dataset-file-out", default="results/nextstrain.data.pt")
191 | parser.add_argument("--start-date", default=START_DATE)
192 | parser.add_argument("--time-step-days", default=14, type=int)
193 | parser.add_argument("--min-region-size", default=50, type=int)
194 | args = parser.parse_args()
195 | args.start_date = parse_date(args.start_date)
196 | main(args)
197 |
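Note: a hypothetical input/output pair for coarsen_locations() above, assuming min_region_size == 50: regions with too few samples collapse to their country.

    from types import SimpleNamespace

    args = SimpleNamespace(min_region_size=50)
    counts = {"USA / Utah": 10, "USA / California": 500, "USA": 7}
    locations, coarsen_location = coarsen_locations(args, counts)
    # locations        == ["USA", "USA / California"]
    # coarsen_location == {"USA / Utah": "USA"}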
--------------------------------------------------------------------------------
/scripts/pull_gisaid.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh -ex
2 |
3 | # Ensure data directory (or a link) exists.
4 | test -e results || mkdir results
5 |
6 | # Download.
7 | curl -u $GISAID_USERNAME:$GISAID_PASSWORD --retry 4 \
8 | https://www.epicov.org/epi3/3p/$GISAID_FEED/export/provision.json.xz \
9 | > results/gisaid.json.xz
10 |
11 | # Decompress, keeping the original.
12 | xz -d -k -f -T0 -v results/gisaid.json.xz
13 |
--------------------------------------------------------------------------------
/scripts/pull_nextstrain.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh -ex
2 |
3 | mkdir -p results/nextstrain
4 |
5 | curl https://data.nextstrain.org/files/ncov/open/metadata.tsv.gz -o results/nextstrain/metadata.tsv.gz
6 |
7 | gunzip -kvf results/nextstrain/metadata.tsv.gz
8 |
--------------------------------------------------------------------------------
/scripts/pull_usher.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh -ex
2 |
3 | mkdir -p results/usher
4 |
5 | # These are the available files:
6 | # public-latest.version.txt
7 | # public-latest.metadata.tsv.gz
8 | # public-latest.all.masked.pb.gz
9 | # public-latest.all.masked.vcf.gz
10 | # public-latest.all.nwk.gz
11 | url='http://hgdownload.soe.ucsc.edu/goldenPath/wuhCor1/UShER_SARS-CoV-2'
12 | for name in version.txt metadata.tsv.gz all.masked.pb.gz
13 | do
14 | curl $url/public-latest.$name -o results/usher/$name
15 | done
16 |
17 | gunzip -kvf results/usher/*.gz
18 |
--------------------------------------------------------------------------------
/scripts/rank_mutations.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import argparse
5 | import functools
6 | import logging
7 | import math
8 | import os
9 |
10 | import torch
11 | from pyro import poutine
12 |
13 | from pyrocov import mutrans
14 |
15 | logger = logging.getLogger(__name__)
16 | logging.basicConfig(format="%(relativeCreated) 9d %(message)s", level=logging.INFO)
17 |
18 |
19 | def cached(filename):
20 | def decorator(fn):
21 | @functools.wraps(fn)
22 | def cached_fn(*args, **kwargs):
23 | f = filename(*args, **kwargs) if callable(filename) else filename
24 | if args[0].force or not os.path.exists(f):
25 | result = fn(*args, **kwargs)
26 | logger.info(f"saving {f}")
27 | torch.save(result, f)
28 | else:
29 | logger.info(f"loading cached {f}")
30 | result = torch.load(f, map_location=args[0].device)
31 | return result
32 |
33 | return cached_fn
34 |
35 | return decorator
36 |
37 |
38 | @cached("results/mutrans.data.pt")
39 | def load_data(args):
40 | return mutrans.load_gisaid_data(device=args.device)
41 |
42 |
43 | @cached("results/rank_mutations.rank_mf_svi.pt")
44 | def rank_mf_svi(args, dataset):
45 | result = mutrans.fit_mf_svi(
46 | dataset,
47 | mutrans.model,
48 | learning_rate=args.svi_learning_rate,
49 | num_steps=args.svi_num_steps,
50 | log_every=args.log_every,
51 | seed=args.seed,
52 | )
53 | result["args"] = (args,)
54 | sigma = result["mean"] / result["std"]
55 | result["ranks"] = sigma.sort(0, descending=True).indices
56 | result["cond_data"] = {
57 | "feature_scale": result["median"]["feature_scale"].item(),
58 | "concentration": result["median"]["concentration"].item(),
59 | }
60 | del result["guide"]
61 | return result
62 |
63 |
64 | @cached("results/rank_mutations.rank_full_svi.pt")
65 | def rank_full_svi(args, dataset):
66 | result = mutrans.fit_full_svi(
67 | dataset,
68 | mutrans.model,
69 | learning_rate=args.full_learning_rate,
70 | learning_rate_decay=args.full_learning_rate_decay,
71 | num_steps=args.full_num_steps,
72 | log_every=args.log_every,
73 | seed=args.seed,
74 | )
75 | result["args"] = (args,)
76 | result["mean"] = result["params"]["rate_coef_loc"]
77 | scale_tril = result["params"]["rate_coef_scale_tril"]
78 | result["cov"] = scale_tril @ scale_tril.T
79 | result["var"] = result["cov"].diag()
80 | result["std"] = result["var"].sqrt()
81 | sigma = result["mean"] / result["std"]
82 | result["ranks"] = sigma.sort(0, descending=True).indices
83 | result["cond_data"] = {
84 | "feature_scale": result["median"]["feature_scale"].item(),
85 | "concentration": result["median"]["concentration"].item(),
86 | }
87 | return result
88 |
89 |
90 | @cached("results/rank_mutations.hessian.pt")
91 | def compute_hessian(args, dataset, result):
92 | logger.info("Computing Hessian")
93 | features = dataset["features"]
94 | weekly_clades = dataset["weekly_clades"]
95 | rate_coef = result["median"]["rate_coef"].clone().requires_grad_()
96 |
97 | cond_data = result["median"].copy()
98 | cond_data.pop("rate")
99 | cond_data.pop("rate_coef")
100 | model = poutine.condition(mutrans.model, cond_data)
101 |
102 | def log_prob(rate_coef):
103 | with poutine.trace() as tr:
104 | with poutine.condition(data={"rate_coef": rate_coef}):
105 | model(weekly_clades, features)
106 | return tr.trace.log_prob_sum()
107 |
108 | hessian = torch.autograd.functional.hessian(
109 | log_prob,
110 | rate_coef,
111 | create_graph=False,
112 | strict=True,
113 | )
114 |
115 | result = {
116 | "args": args,
117 | "mutations": dataset["mutations"],
118 | "initial_ranks": result,
119 | "mean": result["mean"],
120 | "hessian": hessian,
121 | }
122 |
123 | logger.info("Computing covariance")
124 | result["cov"] = _sym_inverse(-hessian)
125 | result["var"] = result["cov"].diag()
126 | result["std"] = result["var"].sqrt()
127 | sigma = result["mean"] / result["std"]
128 | result["ranks"] = sigma.sort(0, descending=True).indices
129 | return result
130 |
131 |
132 | def _sym_inverse(mat):
133 | eye = torch.eye(len(mat))
134 | for exponent in [-math.inf] + list(range(-20, 1)):
135 | eps = 10**exponent
136 | try:
137 | u = torch.cholesky(eye * eps + mat)
138 | except RuntimeError as e:
139 | error = e  # "e" is unbound after this block, so keep a reference for re-raising.
140 | continue
141 | logger.info(f"Added {eps:g} to Hessian diagonal")
142 | return torch.cholesky_inverse(u)
143 | raise error from None
144 |
145 |
146 | def _fit_map_filename(args, dataset, cond_data, guide=None, without_feature=None):
147 | return f"results/rank_mutations.{guide is None}.{without_feature}.pt"
148 |
149 |
150 | @cached(_fit_map_filename)
151 | def fit_map(args, dataset, cond_data, guide=None, without_feature=None):
152 | if without_feature is not None:
153 | # Drop feature.
154 | dataset = dataset.copy()
155 | dataset["features"] = dataset["features"].clone()
156 | dataset["features"][:, without_feature] = 0
157 |
158 | # Condition model.
159 | cond_data = {k: torch.as_tensor(v) for k, v in cond_data.items()}
160 | model = poutine.condition(mutrans.model, cond_data)
161 |
162 | # Fit.
163 | result = mutrans.fit_map(
164 | dataset,
165 | model,
166 | guide,
167 | learning_rate=args.map_learning_rate,
168 | num_steps=args.map_num_steps,
169 | log_every=args.log_every,
170 | seed=args.seed,
171 | )
172 |
173 | result["args"] = args
174 | result["guide"] = guide
175 | if without_feature is None:
176 | result["mutation"] = None
177 | else:
178 | result["mutation"] = dataset["mutations"][without_feature]
179 | return result
180 |
181 |
182 | def rank_map(args, dataset, initial_ranks):
183 | """
184 | Given an initial approximate ranking of features, compute MAP log
185 | likelihood ratios of the most significant features.
186 | """
187 | # Fit an initial model for warm-starting.
188 | cond_data = initial_ranks["cond_data"]
189 | if args.warm_start:
190 | guide = fit_map(args, dataset, cond_data)["guide"]
191 | else:
192 | guide = None
193 |
194 | # Evaluate on the null hypothesis + the most positive features.
195 | dropouts = {}
196 | for feature in [None] + initial_ranks["ranks"].tolist():
197 | dropouts[feature] = fit_map(args, dataset, cond_data, guide, feature)
198 |
199 | result = {
200 | "args": args,
201 | "mutations": dataset["mutations"],
202 | "initial_ranks": initial_ranks,
203 | "dropouts": dropouts,
204 | }
205 | logger.info("saving results/rank_mutations.pt")
206 | torch.save(result, "results/rank_mutations.pt")
207 |
208 |
209 | def main(args):
210 | if args.double:
211 | torch.set_default_dtype(torch.double)
212 | if args.cuda:
213 | torch.set_default_tensor_type(
214 | torch.cuda.DoubleTensor if args.double else torch.cuda.FloatTensor
215 | )
216 |
217 | dataset = load_data(args)
218 | if args.full:
219 | initial_ranks = rank_full_svi(args, dataset)
220 | else:
221 | initial_ranks = rank_mf_svi(args, dataset)
222 | if args.hessian:
223 | compute_hessian(args, dataset, initial_ranks)
224 | if args.dropout:
225 | rank_map(args, dataset, initial_ranks)
226 |
227 |
228 | if __name__ == "__main__":
229 | parser = argparse.ArgumentParser(
230 | description="Rank mutations via SVI and leave-feature-out MAP"
231 | )
232 | parser.add_argument("--full", action="store_true")
233 | parser.add_argument("--full-num-steps", default=10001, type=int)
234 | parser.add_argument("--full-learning-rate", default=0.01, type=float)
235 | parser.add_argument("--full-learning-rate-decay", default=0.01, type=float)
236 | parser.add_argument("--svi-num-steps", default=1001, type=int)
237 | parser.add_argument("--svi-learning-rate", default=0.05, type=float)
238 | parser.add_argument("--map-num-steps", default=1001, type=int)
239 | parser.add_argument("--map-learning-rate", default=0.05, type=float)
240 | parser.add_argument("--dropout", action="store_true")
241 | parser.add_argument("--hessian", action="store_true")
242 | parser.add_argument("--warm-start", action="store_true")
243 | parser.add_argument("--double", action="store_true", default=True)
244 | parser.add_argument("--single", action="store_false", dest="double")
245 | parser.add_argument(
246 | "--cuda", action="store_true", default=torch.cuda.is_available()
247 | )
248 | parser.add_argument("--cpu", dest="cuda", action="store_false")
249 | parser.add_argument("--seed", default=20210319, type=int)
250 | parser.add_argument("-f", "--force", action="store_true")
251 | parser.add_argument("-l", "--log-every", default=100, type=int)
252 | args = parser.parse_args()
253 | args.device = "cuda" if args.cuda else "cpu"
254 | main(args)
255 |
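Note: a usage sketch for the @cached decorator above (function and filename hypothetical; a results/ directory is assumed to exist). The first positional argument must expose .force and .device, as this script's argparse namespace does.

    from types import SimpleNamespace

    import torch

    @cached("results/example.pt")
    def expensive_computation(args):
        return torch.arange(3)

    args = SimpleNamespace(force=False, device="cpu")
    expensive_computation(args)  # computes and saves results/example.pt
    expensive_computation(args)  # loads the cached file instead of recomputing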
--------------------------------------------------------------------------------
/scripts/run_backtesting.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | backtesting_days=$(seq -s, 150 14 550)
4 |
5 | python mutrans.py --backtesting-max-day $backtesting_days --forecast-steps 12
6 |
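Note: for reference, seq -s, 150 14 550 expands to the comma-separated list below (every 14 days from day 150 through day 542), passed as a single --backtesting-max-day argument:

    150,164,178,192,206,220,234,248,262,276,290,304,318,332,346,360,374,388,402,416,430,444,458,472,486,500,514,528,542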
--------------------------------------------------------------------------------
/scripts/update_headers.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import argparse
5 | import glob
6 | import os
7 | import sys
8 |
9 | root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
10 | blacklist = ["/build/", "/dist/", "/pyrocov/external/"]
11 | file_types = [
12 | ("*.py", "# {}"),
13 | ("*.cpp", "// {}"),
14 | ]
15 |
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument("--check", action="store_true")
18 | args = parser.parse_args()
19 | dirty = []
20 |
21 | for basename, comment in file_types:
22 | copyright_line = comment.format("Copyright Contributors to the Pyro-Cov project.\n")
23 | # See https://spdx.org/ids-how
24 | spdx_line = comment.format("SPDX-License-Identifier: Apache-2.0\n")
25 |
26 | filenames = glob.glob(os.path.join(root, "**", basename), recursive=True)
27 | filenames.sort()
28 | filenames = [
29 | filename
30 | for filename in filenames
31 | if not any(word in filename for word in blacklist)
32 | ]
33 | for filename in filenames:
34 | with open(filename) as f:
35 | lines = f.readlines()
36 |
37 | # Ignore empty files like __init__.py
38 | if all(line.isspace() for line in lines):
39 | continue
40 |
41 | # Ensure the first few lines are copyright notices.
42 | changed = False
43 | lineno = 0
44 | if not lines[lineno].startswith(comment.format("Copyright")):
45 | lines.insert(lineno, copyright_line)
46 | changed = True
47 | lineno += 1
48 | while lines[lineno].startswith(comment.format("Copyright")):
49 | lineno += 1
50 |
51 | # Ensure next line is an SPDX short identifier.
52 | if not lines[lineno].startswith(comment.format("SPDX-License-Identifier")):
53 | lines.insert(lineno, spdx_line)
54 | changed = True
55 | lineno += 1
56 |
57 | # Ensure next line is blank.
58 | if not lines[lineno].isspace():
59 | lines.insert(lineno, "\n")
60 | changed = True
61 |
62 | if not changed:
63 | continue
64 |
65 | if args.check:
66 | dirty.append(filename)
67 | continue
68 |
69 | with open(filename, "w") as f:
70 | f.write("".join(lines))
71 |
72 | print("updated {}".format(filename[len(root) + 1 :]))
73 |
74 | if dirty:
75 | print("The following files need license headers:\n{}".format("\n".join(dirty)))
76 | print("Please run 'make license'")
77 | sys.exit(1)
78 |
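Note: as an illustration (file name hypothetical), a module example.py that begins directly with "import math" would be rewritten by the script above to start with:

    # Copyright Contributors to the Pyro-Cov project.
    # SPDX-License-Identifier: Apache-2.0

    import math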
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 120
3 | ignore = E741,E203,W503
4 | per_file_ignores =
5 | pyrocov/io.py:E226
6 | exclude =
7 | generate_epiToPublicAndDate.py
8 |
9 | [isort]
10 | profile = black
11 | skip_glob = .ipynb_checkpoints
12 | known_first_party = pyrocov
13 | known_third_party = opt_einsum, pyro, torch, torchvision
14 |
15 | [tool:pytest]
16 | filterwarnings = error
17 | ignore::PendingDeprecationWarning
18 | ignore::DeprecationWarning
19 | once::DeprecationWarning
20 | ignore:Failed to find .* pangolin aliases may be stale:RuntimeWarning
21 |
22 | [mypy]
23 | ignore_missing_imports = True
24 | allow_redefinition = True
25 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import re
5 | import sys
6 |
7 | from setuptools import find_packages, setup
8 |
9 | with open("pyrocov/__init__.py") as f:
10 | for line in f:
11 | match = re.match('^__version__ = "(.*)"$', line)
12 | if match:
13 | __version__ = match.group(1)
14 | break
15 |
16 | try:
17 | long_description = open("README.md", encoding="utf-8").read()
18 | except Exception as e:
19 | sys.stderr.write("Failed to read README.md: {}\n".format(e))
20 | sys.stderr.flush()
21 | long_description = ""
22 |
23 | setup(
24 | name="pyrocov",
25 | version="0.1.0",
26 | description="Pyro tools for Sars-CoV-2 analysis",
27 | long_description=long_description,
28 | long_description_content_type="text/markdown",
29 | packages=find_packages(include=["pyrocov"]),
30 | url="http://pyro.ai",
31 | author="Pyro team at the Broad Institute of MIT and Harvard",
32 | author_email="fobermey@broadinstitute.org",
33 | install_requires=[
34 | "biopython>=1.54",
35 | "pyro-ppl>=1.7",
36 | "geopy",
37 | "gpytorch",
38 | "scikit-learn",
39 | "umap-learn",
40 | "mappy",
41 | "protobuf>=3.12,<3.13", # pinned by usher
42 | "tqdm",
43 | "colorcet",
44 | ],
45 | extras_require={
46 | "test": [
47 | "black",
48 | "isort>=5.0",
49 | "flake8",
50 | "pytest>=5.0",
51 | "pytest-xdist",
52 | "mypy>=0.812",
53 | "types-protobuf",
54 | ],
55 | },
56 | python_requires=">=3.6",
57 | keywords="pyro pytorch phylogenetic machine learning",
58 | license="Apache 2.0",
59 | classifiers=[
60 | "Intended Audience :: Developers",
61 | "Intended Audience :: Science/Research",
62 | "License :: OSI Approved :: Apache Software License",
63 | "Operating System :: POSIX :: Linux",
64 | "Operating System :: MacOS :: MacOS X",
65 | "Programming Language :: Python :: 3.6",
66 | "Programming Language :: Python :: 3.7",
67 | ],
68 | )
69 |
--------------------------------------------------------------------------------
/test/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2017-2019 Uber Technologies, Inc.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import logging
5 | import os
6 |
7 | # create log handler for tests
8 | level = logging.INFO if "CI" in os.environ else logging.DEBUG
9 | logging.basicConfig(format="%(levelname).1s \t %(message)s", level=level)
10 |
--------------------------------------------------------------------------------
/test/conftest.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import logging
5 | import os
6 |
7 | import pyro
8 |
9 | level = logging.INFO if "CI" in os.environ else logging.DEBUG
10 | logging.basicConfig(format="%(levelname).1s \t %(message)s", level=level)
11 |
12 |
13 | def pytest_runtest_setup(item):
14 | pyro.clear_param_store()
15 | pyro.enable_validation(True)
16 | pyro.set_rng_seed(0)
17 |
--------------------------------------------------------------------------------
/test/test_distributions.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import pytest
5 | import torch
6 | from pyro.distributions.testing.gof import auto_goodness_of_fit
7 |
8 | from pyrocov.distributions import SoftLaplace
9 |
10 | TEST_FAILURE_RATE = 1e-2
11 |
12 |
13 | @pytest.mark.parametrize(
14 | "Dist, params",
15 | [
16 | (SoftLaplace, {"loc": 0.0, "scale": 1.0}),
17 | (SoftLaplace, {"loc": 1.0, "scale": 1.0}),
18 | (SoftLaplace, {"loc": 0.0, "scale": 10.0}),
19 | ],
20 | )
21 | def test_gof(Dist, params):
22 | num_samples = 50000
23 | d = Dist(**params)
24 | samples = d.sample(torch.Size([num_samples]))
25 | probs = d.log_prob(samples).exp()
26 |
27 | # Test each batch independently.
28 | probs = probs.reshape(num_samples, -1)
29 | samples = samples.reshape(probs.shape + d.event_shape)
30 | for b in range(probs.size(-1)):
31 | gof = auto_goodness_of_fit(samples[:, b], probs[:, b])
32 | assert gof > TEST_FAILURE_RATE
33 |
--------------------------------------------------------------------------------
/test/test_io.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import glob
5 | import os
6 |
7 | import pytest
8 | import torch
9 | from Bio import Phylo
10 |
11 | from pyrocov.io import read_alignment, read_nexus_trees, stack_nexus_trees
12 |
13 | ROOT = os.path.dirname(os.path.dirname(__file__))
14 | FILENAME = os.path.join(ROOT, "data", "GTR4G_posterior.trees")
15 |
16 |
17 | @pytest.mark.skipif(not os.path.exists(FILENAME), reason="file unavailable")
18 | @pytest.mark.xfail(reason="Python <3.8 cannot .read() large files", run=False)
19 | def test_bio_phylo_parse():
20 | trees = Phylo.parse(FILENAME, format="nexus")
21 | for tree in trees:
22 | print(tree.count_terminals())
23 |
24 |
25 | @pytest.mark.skipif(not os.path.exists(FILENAME), reason="file unavailable")
26 | @pytest.mark.parametrize("processes", [0, 2])
27 | def test_read_nexus_trees(processes):
28 | trees = read_nexus_trees(FILENAME, max_num_trees=5, processes=processes)
29 | trees = list(trees)
30 | assert len(trees) == 5
31 | for tree in trees:
32 | assert tree.count_terminals() == 772
33 |
34 |
35 | @pytest.mark.skipif(not os.path.exists(FILENAME), reason="file unavailable")
36 | @pytest.mark.parametrize("processes", [0, 2])
37 | def test_stack_nexus_trees(processes):
38 | phylo = stack_nexus_trees(FILENAME, max_num_trees=5, processes=processes)
39 | assert phylo.batch_shape == (5,)
40 |
41 |
42 | @pytest.mark.parametrize("filename", glob.glob("data/treebase/DS*.nex"))
43 | def test_read_alignment(filename):
44 | probs = read_alignment(filename)
45 | assert probs.dim() == 3
46 | assert torch.isfinite(probs).all()
47 |
--------------------------------------------------------------------------------
/test/test_markov_tree.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import pyro.distributions as dist
5 | import pytest
6 | import torch
7 |
8 | from pyrocov.markov_tree import MarkovTree, _interpolate_lmve, _mpm
9 | from pyrocov.phylo import Phylogeny
10 |
11 |
12 | def grad(output, inputs, **kwargs):
13 | if not output.requires_grad:
14 | return list(map(torch.zeros_like, inputs))
15 | return torch.autograd.grad(output, inputs, **kwargs)
16 |
17 |
18 | @pytest.mark.parametrize("size", range(2, 10))
19 | def test_mpm(size):
20 | matrix = torch.randn(size, size).exp()
21 | matrix /= matrix.sum(dim=-1, keepdim=True) # Make stochastic.
22 | matrix = (matrix + 4 * torch.eye(size)) / 5 # Make diagonally dominant.
23 | vector = torch.randn(size)
24 |
25 | for t in range(0, 10):
26 | expected = vector @ matrix.matrix_power(t)
27 | actual = _mpm(matrix, torch.tensor(float(t)), vector)
28 |
29 | assert actual.shape == expected.shape
30 | assert torch.allclose(actual, expected)
31 |
32 |
33 | @pytest.mark.parametrize("size", [2, 3, 4, 5])
34 | @pytest.mark.parametrize("duration", [1, 2, 3, 4, 5])
35 | def test_interpolate_lmve_smoke(size, duration):
36 | matrix = torch.randn(duration, size, size).exp()
37 | log_vector = torch.randn(size)
38 | t0 = -0.6
39 | while t0 < duration + 0.9:
40 | t1 = t0 + 0.2
41 | while t1 < duration + 0.9:
42 | actual = _interpolate_lmve(
43 | torch.tensor(t0), torch.tensor(t1), matrix, log_vector
44 | )
45 | assert actual.shape == log_vector.shape
46 | t1 += 1
47 | t0 += 1
48 |
49 |
50 | @pytest.mark.parametrize("num_states", [3, 7])
51 | @pytest.mark.parametrize("num_leaves", [4, 16, 17])
52 | @pytest.mark.parametrize("duration", [1, 5])
53 | @pytest.mark.parametrize("num_samples", [1, 2, 3])
54 | def test_markov_tree_log_prob(num_samples, duration, num_leaves, num_states):
55 | phylo = Phylogeny.generate(num_leaves, num_samples=num_samples)
56 | phylo.times.mul_(duration * 0.25).add_(0.75 * duration)
57 | phylo.times.round_() # Required for naive-vs-likelihood agreement.
58 |
59 | leaf_state = dist.Categorical(torch.ones(num_states)).sample([num_leaves])
60 |
61 | state_trans = torch.randn(duration, num_states, num_states).mul(0.1).exp()
62 | state_trans /= state_trans.sum(dim=-1, keepdim=True)
63 | state_trans += 4 * torch.eye(num_states)
64 | state_trans /= state_trans.sum(dim=-1, keepdim=True)
65 | state_trans.requires_grad_()
66 |
67 | dist1 = MarkovTree(phylo, state_trans, method="naive")
68 | dist2 = MarkovTree(phylo, state_trans, method="likelihood")
69 |
70 | logp1 = dist1.log_prob(leaf_state)
71 | logp2 = dist2.log_prob(leaf_state)
72 | assert torch.allclose(logp1, logp2)
73 |
74 | grad1 = grad(logp1.logsumexp(0), [state_trans], allow_unused=True)[0]
75 | grad2 = grad(logp2.logsumexp(0), [state_trans], allow_unused=True)[0]
76 | grad1 = grad1 - grad1.mean(dim=-1, keepdim=True)
77 | grad2 = grad2 - grad2.mean(dim=-1, keepdim=True)
78 | assert torch.allclose(grad1, grad2, rtol=1e-4, atol=1e-4)
79 |
--------------------------------------------------------------------------------
/test/test_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import pyro.distributions as dist
5 | import pytest
6 | import torch
7 | from torch.autograd import grad
8 |
9 | from pyrocov.ops import (
10 | logistic_logsumexp,
11 | sparse_multinomial_likelihood,
12 | sparse_poisson_likelihood,
13 | )
14 |
15 |
16 | @pytest.mark.parametrize("T,P,S", [(5, 6, 7)])
17 | @pytest.mark.parametrize("backend", ["sequential"])
18 | def test_logistic_logsumexp(T, P, S, backend):
19 | alpha = torch.randn(P, S, requires_grad=True)
20 | beta = torch.randn(P, S, requires_grad=True)
21 | delta = torch.randn(P, S, requires_grad=True)
22 | tau = torch.randn(T, P)
23 |
24 | expected = logistic_logsumexp(alpha, beta, delta, tau, backend="naive")
25 | actual = logistic_logsumexp(alpha, beta, delta, tau, backend=backend)
26 | assert torch.allclose(actual, expected)
27 |
28 | probe = torch.randn(expected.shape)
29 | expected_grads = grad((probe * expected).sum(), [alpha, beta, delta])
30 | actual_grads = grad((probe * actual).sum(), [alpha, beta, delta])
31 | for e, a, name in zip(expected_grads, actual_grads, ["alpha", "beta", "delta"]):
32 | assert torch.allclose(a, e), name
33 |
34 |
35 | @pytest.mark.parametrize("T,P,S", [(2, 3, 4), (5, 6, 7), (8, 9, 10)])
36 | def test_sparse_poisson_likelihood(T, P, S):
37 | log_rate = torch.randn(T, P, S)
38 | d = dist.Poisson(log_rate.exp())
39 | value = d.sample()
40 | assert 0.1 < (value == 0).float().mean() < 0.9, "weak test"
41 | expected = d.log_prob(value).sum()
42 |
43 | full_log_rate = log_rate.logsumexp(-1)
44 | nnz = value.nonzero(as_tuple=True)
45 | nonzero_value = value[nnz]
46 | nonzero_log_rate = log_rate[nnz]
47 | actual = sparse_poisson_likelihood(full_log_rate, nonzero_log_rate, nonzero_value)
48 | assert torch.allclose(actual, expected)
49 |
50 |
51 | @pytest.mark.parametrize("T,P,S", [(2, 3, 4), (5, 6, 7), (8, 9, 10)])
52 | def test_sparse_multinomial_likelihood(T, P, S):
53 | logits = torch.randn(T, P, S)
54 | value = dist.Poisson(logits.exp()).sample()
55 |
56 | d = dist.Multinomial(logits=logits, validate_args=False)
57 | assert 0.1 < (value == 0).float().mean() < 0.9, "weak test"
58 | expected = d.log_prob(value).sum()
59 |
60 | logits = logits.log_softmax(-1)
61 | total_count = value.sum(-1)
62 | nnz = value.nonzero(as_tuple=True)
63 | nonzero_value = value[nnz]
64 | nonzero_logits = logits[nnz]
65 | actual = sparse_multinomial_likelihood(total_count, nonzero_logits, nonzero_value)
66 | assert torch.allclose(actual, expected)
67 |
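Note: the identity that test_sparse_poisson_likelihood above exercises, stated for reference. Zero counts contribute only their -lambda term, which is absorbed into one sum over all rates, so only nonzero entries need individual terms:

    \log p(n;\lambda) = n \log\lambda - \lambda - \log n!
    \quad\Longrightarrow\quad
    \sum_{t,p,s} \log p(n_{tps};\lambda_{tps})
      = -\sum_{t,p,s} \lambda_{tps}
      + \sum_{n_{tps} > 0} \bigl( n_{tps} \log\lambda_{tps} - \log n_{tps}! \bigr)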
--------------------------------------------------------------------------------
/test/test_phylo.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import pytest
5 | import torch
6 |
7 | from pyrocov.phylo import Phylogeny
8 |
9 |
10 | @pytest.mark.parametrize("num_leaves", range(1, 50))
11 | def test_smoke(num_leaves):
12 | phylo = Phylogeny.generate(num_leaves)
13 | phylo.num_lineages()
14 | phylo.hash_topology()
15 | phylo.time_mrca()
16 |
17 |
18 | @pytest.mark.parametrize("num_leaves", range(1, 10))
19 | @pytest.mark.parametrize("num_samples", range(1, 5))
20 | def test_smoke_batch(num_leaves, num_samples):
21 | phylo = Phylogeny.generate(num_leaves, num_samples=num_samples)
22 | phylo.num_lineages()
23 | phylo.hash_topology()
24 | phylo.time_mrca()
25 |
26 |
27 | def test_time_mrca():
28 | # 0 0
29 | # 1 \ 1
30 | # /| 2 2
31 | # 3 4 |\ 3
32 | # 5 \ 4
33 | # / \ 6 5
34 | # 7 8 6
35 | times = torch.tensor([0.0, 1.0, 2.0, 3.0, 3.0, 4.0, 5.0, 6.0, 6.0])
36 | parents = torch.tensor([-1, 0, 0, 1, 1, 2, 2, 5, 5])
37 | leaves = torch.tensor([3, 4, 6, 7, 8])
38 | phylo = Phylogeny(times, parents, leaves)
39 |
40 | actual = phylo.time_mrca()
41 | expected = torch.tensor(
42 | [
43 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
44 | [0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
45 | [0.0, 0.0, 2.0, 0.0, 0.0, 2.0, 2.0, 2.0, 2.0],
46 | [0.0, 1.0, 0.0, 3.0, 1.0, 0.0, 0.0, 0.0, 0.0],
47 | [0.0, 1.0, 0.0, 1.0, 3.0, 0.0, 0.0, 0.0, 0.0],
48 | [0.0, 0.0, 2.0, 0.0, 0.0, 4.0, 2.0, 4.0, 4.0],
49 | [0.0, 0.0, 2.0, 0.0, 0.0, 2.0, 5.0, 2.0, 2.0],
50 | [0.0, 0.0, 2.0, 0.0, 0.0, 4.0, 2.0, 6.0, 4.0],
51 | [0.0, 0.0, 2.0, 0.0, 0.0, 4.0, 2.0, 4.0, 6.0],
52 | ]
53 | )
54 | assert (actual == expected).all()
55 |
56 | actual = phylo.leaf_time_mrca()
57 | expected = torch.tensor(
58 | [
59 | [3.0, 1.0, 0.0, 0.0, 0.0],
60 | [1.0, 3.0, 0.0, 0.0, 0.0],
61 | [0.0, 0.0, 5.0, 2.0, 2.0],
62 | [0.0, 0.0, 2.0, 6.0, 4.0],
63 | [0.0, 0.0, 2.0, 4.0, 6.0],
64 | ]
65 | )
66 | assert (actual == expected).all()
67 |
--------------------------------------------------------------------------------
/test/test_sarscov2.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | from pyrocov.sarscov2 import nuc_mutations_to_aa_mutations
5 |
6 |
7 | def test_nuc_to_aa():
8 | assert nuc_mutations_to_aa_mutations(["A23403G"]) == ["S:D614G"]
9 |
--------------------------------------------------------------------------------
/test/test_sketch.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import re
5 |
6 | import pyro.distributions as dist
7 | import pytest
8 | import torch
9 |
10 | from pyrocov.sketch import AMSSketcher, ClockSketcher, KmerCounter
11 |
12 |
13 | def random_string(size):
14 | probs = torch.tensor([1.0, 1.0, 1.0, 1.0, 0.05])
15 | probs /= probs.sum()
16 | string = "".join("ACGTN"[i] for i in dist.Categorical(probs).sample([size]))
17 | return string
18 |
19 |
20 | def test_kmer_counter():
21 | string = random_string(10000)
22 |
23 | results = {}
24 | for backend in ["python", "cpp"]:
25 | results[backend] = KmerCounter(backend=backend)
26 | for part in re.findall("[ACTG]+", string):
27 | results[backend].update(part)
28 | results[backend].flush()
29 |
30 | expected = results["python"]
31 | actual = results["cpp"]
32 | assert actual == expected
33 |
34 |
35 | @pytest.mark.parametrize("min_k,max_k", [(2, 2), (2, 4), (3, 12)])
36 | @pytest.mark.parametrize("bits", [16])
37 | def test_string_to_soft_hash(min_k, max_k, bits):
38 | string = random_string(1000)
39 |
40 | results = {}
41 | for backend in ["python", "cpp"]:
42 | results[backend] = torch.empty(64)
43 | sketcher = AMSSketcher(min_k=min_k, max_k=max_k, bits=bits, backend=backend)
44 | sketcher.string_to_soft_hash(string, results[backend])
45 |
46 | expected = results["python"]
47 | actual = results["cpp"]
48 | tol = expected.std().item() * 1e-6
49 | assert (actual - expected).abs().max().item() < tol
50 |
51 |
52 | @pytest.mark.parametrize("k", [2, 3, 4, 5, 8, 16, 32])
53 | def test_string_to_clock_hash(k):
54 | string = random_string(1000)
55 |
56 | results = {}
57 | for backend in ["python", "cpp"]:
58 | sketcher = ClockSketcher(k, backend=backend)
59 | results[backend] = sketcher.init_sketch()
60 | sketcher.string_to_hash(string, results[backend])
61 |
62 | expected = results["python"]
63 | actual = results["cpp"]
64 | assert (actual.clocks == expected.clocks).all()
65 | assert (actual.count == expected.count).all()
66 |
67 |
68 | @pytest.mark.parametrize("k", [2, 3, 4, 5, 8, 16, 32])
69 | @pytest.mark.parametrize("size", [20000])
70 | def test_clock_cdiff(k, size):
71 | n = 10
72 | strings = [random_string(size) for _ in range(n)]
73 | sketcher = ClockSketcher(k)
74 | sketch = sketcher.init_sketch(n)
75 | for i, string in enumerate(strings):
76 | sketcher.string_to_hash(string, sketch[i])
77 |
78 | cdiff = sketcher.cdiff(sketch, sketch)
79 | assert cdiff.shape == (n, n)
80 | assert (cdiff.clocks.transpose(0, 1) == -cdiff.clocks).all()
81 | assert (cdiff.clocks.diagonal(dim1=0, dim2=1) == 0).all()
82 | assert (cdiff.count.diagonal(dim1=0, dim2=1) == 0).all()
83 | mask = torch.arange(n) < torch.arange(n).unsqueeze(-1)
84 | mean = cdiff.clocks[mask].abs().float().mean().item()
85 | assert mean > 64 - 10, mean
86 |
--------------------------------------------------------------------------------
/test/test_softmax_tree.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import math
5 |
6 | import pyro
7 | import pyro.distributions as dist
8 | import pytest
9 | import torch
10 | from pyro.infer import SVI, Trace_ELBO
11 | from pyro.optim import Adam
12 |
13 | from pyrocov.markov_tree import MarkovTree
14 | from pyrocov.phylo import Phylogeny
15 | from pyrocov.softmax_tree import SoftmaxTree
16 |
17 |
18 | @pytest.mark.parametrize("num_bits", [2, 3, 4, 5, 10])
19 | @pytest.mark.parametrize("num_leaves", [2, 3, 4, 5, 10])
20 | def test_rsample(num_leaves, num_bits):
21 | phylo = Phylogeny.generate(num_leaves)
22 | leaf_times = phylo.times[phylo.leaves]
23 | bit_times = torch.randn(num_bits)
24 | logits = torch.randn(num_leaves, num_bits)
25 | tree = SoftmaxTree(leaf_times, bit_times, logits)
26 | value = tree.rsample()
27 | tree.log_prob(value)
28 |
29 |
30 | def model(leaf_times, leaf_states, num_features):
31 | assert len(leaf_times) == len(leaf_states)
32 |
33 | # Timed tree concerning reproductive behavior only.
34 | coal_params = pyro.sample("coal_params", dist.CoalParamPrior("TODO")) # global
35 | # Note this is where our coalescent model assumes geographically
36 | # homogeneous reproductive rate, which is not realistic.
37 | # See appendix of (Vaughan et al. 2014) for discussion of this assumption.
38 | phylogeny = pyro.sample("phylogeny", dist.Coalescent(coal_params, leaf_times))
39 |
40 | # This is compatible with subsampling features, but not leaves.
41 | subs_params = pyro.sample("subs_params", dist.GTRGamma("TODO")) # global
42 | with pyro.plate("features", num_features, leaf_states.size(-1)):
43 | # This is similar to the phylogeographic likelihood in the pyro-cov repo.
44 | # This is simpler (because it is time-homogeneous)
45 | # but more complex in that it is batched.
46 | # This computes mutation likelihood via dynamic programming.
47 | pyro.sample("leaf_times", MarkovTree(phylogeny, subs_params), obs=leaf_states)
48 |
49 |
50 | def guide(leaf_times, leaf_states, num_features, *, logits_fn=None):
51 | assert len(leaf_times) == len(leaf_states)
52 |
53 | # Sample all continuous latents in a giant correlated auxiliary.
54 | aux = pyro.sample("aux", dist.LowRankMultivariateNormal("TODO"))
55 | # Split it up (TODO switch to EasyGuide).
56 | pyro.sample("coal_params", dist.Delta(aux["TODO"])) # global
57 | pyro.sample("subs_params", dist.Delta(aux["TODO"])) # global
58 | # These are the times of each bit in the embedding vector.
59 | bit_times = pyro.sample(
60 | "bit_times", dist.Delta(aux["TODO"]), infer={"is_auxiliary": True}
61 | )
62 |
63 | # Learn parameters of the discrete distributions,
64 | # possibly conditioned on continuous latents.
65 | if logits_fn is not None:
66 | # Amortized guide, compatible with subsampling leaves but not features.
67 | logits = logits_fn(leaf_states, leaf_times) # batched over leaves
68 | else:
69 | # Fully local guide, compatible with subsampling features but not leaves.
70 | with pyro.plate("leaves", len(leaf_times)):
71 | logits = pyro.param(
72 | "logits", lambda: torch.randn(leaf_times.shape), event_dim=0
73 | )
74 | assert len(logits) == len(leaf_times)
75 |
76 |     pyro.sample("phylogeny", SoftmaxTree(leaf_times, bit_times, logits))
77 |
78 |
79 | @pytest.mark.xfail(reason="WIP")
80 | @pytest.mark.parametrize("num_features", [4])
81 | @pytest.mark.parametrize("num_leaves", [2, 3, 4, 5, 10, 100])
82 | def test_svi(num_leaves, num_features):
83 | phylo = Phylogeny.generate(num_leaves)
84 | leaf_times = phylo.times[phylo.leaves]
85 | leaf_states = torch.full((num_leaves, num_features), 0.5).bernoulli()
86 |
87 | svi = SVI(model, guide, Adam({"lr": 1e-4}), Trace_ELBO())
88 | for i in range(2):
89 |         loss = svi.step(leaf_times, leaf_states, num_features)
90 | assert math.isfinite(loss)
91 |
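The amortized branch of `guide` assumes a caller-supplied `logits_fn` that maps per-leaf observations to per-leaf embedding logits; neither the guide nor this test defines one. Here is a minimal sketch of what such a function could look like; the network shape, hidden size, and the `make_logits_fn` name are illustrative assumptions, not part of the repo:

import torch

def make_logits_fn(num_features, num_bits, hidden=32):
    # Hypothetical amortization network: per-leaf states and times in,
    # per-leaf embedding logits out. Because each leaf's logits depend only
    # on that leaf's data, leaves can be freely subsampled.
    net = torch.nn.Sequential(
        torch.nn.Linear(num_features + 1, hidden),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden, num_bits),
    )

    def logits_fn(leaf_states, leaf_times):
        x = torch.cat([leaf_states.float(), leaf_times[:, None]], dim=-1)
        return net(x)  # shape (num_leaves, num_bits), batched over leaves

    return logits_fn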
--------------------------------------------------------------------------------
/test/test_strains.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import pytest
5 |
6 | from pyrocov.strains import TimeSpaceStrainModel, simulate
7 |
8 |
9 | @pytest.mark.parametrize("T", [32])
10 | @pytest.mark.parametrize("R", [6])
11 | @pytest.mark.parametrize("S", [5])
12 | def test_strains(T, R, S):
13 | dataset = simulate(T, R, S)
14 | model = TimeSpaceStrainModel(**dataset)
15 | model.fit(num_steps=101, haar=False)
16 | model.median()
17 |
--------------------------------------------------------------------------------
/test/test_substitution.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import pyro.poutine as poutine
5 | import pytest
6 | import torch
7 | from pyro.infer.autoguide import AutoDelta
8 |
9 | from pyrocov.substitution import GeneralizedTimeReversible, JukesCantor69
10 |
11 |
12 | @pytest.mark.parametrize("Model", [JukesCantor69, GeneralizedTimeReversible])
13 | def test_matrix_exp(Model):
14 | model = Model()
15 | guide = AutoDelta(model)
16 |     guide()  # run once to initialize guide parameters
17 | trace = poutine.trace(guide).get_trace()
18 |     t = torch.randn(10).exp()  # random positive branch lengths
19 | with poutine.replay(trace=trace):
20 | m = model()
21 | assert torch.allclose(model(), m)
22 |
23 | exp_mt = (m * t[:, None, None]).matrix_exp()
24 | actual = model.matrix_exp(t)
25 | assert torch.allclose(actual, exp_mt, atol=1e-6)
26 |
27 | actual = model.log_matrix_exp(t)
28 | log_exp_mt = exp_mt.log()
29 | assert torch.allclose(actual, log_exp_mt, atol=1e-6)
30 |
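For intuition about what this test checks: for a substitution rate matrix Q, the transition-probability matrix over a branch of length t is exp(tQ), which under JC69 has a closed form. That makes the dense matrix exponential easy to verify analytically; a standalone sketch (the rate `alpha` and the explicit Q construction are illustrative, not the library's parameterization):

import math
import torch

alpha = 0.3  # off-diagonal substitution rate
Q = alpha * (torch.ones(4, 4) - 4 * torch.eye(4))  # rows sum to zero
t = 1.7
P = (Q * t).matrix_exp()
# JC69 closed form: P_ii(t) = 1/4 + 3/4 * exp(-4 * alpha * t).
expected_diag = 0.25 + 0.75 * math.exp(-4 * alpha * t)
assert torch.allclose(P.diagonal(), torch.full((4,), expected_diag), atol=1e-6)
assert torch.allclose(P.sum(-1), torch.ones(4), atol=1e-6)  # rows sum to one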
--------------------------------------------------------------------------------
/test/test_usher.py:
--------------------------------------------------------------------------------
1 | # Copyright Contributors to the Pyro-Cov project.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | import math
5 | import os
6 | import random
7 | import tempfile
8 | from collections import defaultdict
9 |
10 | from pyrocov.align import PANGOLEARN_DATA
11 | from pyrocov.usher import prune_mutation_tree, refine_mutation_tree
12 |
13 |
14 | def test_refine_prune():
15 | filename1 = os.path.join(PANGOLEARN_DATA, "lineageTree.pb")
16 | with tempfile.TemporaryDirectory() as tmpdirname:
17 | filename2 = os.path.join(tmpdirname, "refinedTree.pb")
18 | filename3 = os.path.join(tmpdirname, "prunedTree.pb")
19 |
20 | # Refine the tree.
21 | fine_to_coarse = refine_mutation_tree(filename1, filename2)
22 |
23 | # Find canonical fine names for each coarse name.
24 | coarse_to_fine = defaultdict(list)
25 | for fine, coarse in fine_to_coarse.items():
26 | coarse_to_fine[coarse].append(fine)
27 | coarse_to_fine = {k: min(vs) for k, vs in coarse_to_fine.items()}
28 |
29 | # Prune the tree, keeping coarse nodes.
30 | weights = {fine: random.lognormvariate(0, 1) for fine in fine_to_coarse}
31 | for fine in coarse_to_fine.values():
32 | weights[fine] = math.inf
33 | prune_mutation_tree(filename2, filename3, weights=weights, max_num_nodes=10000)
34 |
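The weighting scheme above gives every fine node a random positive weight but pins one canonical fine node per coarse clade at infinite weight, so pruning down to `max_num_nodes` may drop redundant fine structure but never a clade's representative. A toy illustration of the canonicalization step, with made-up node names:

import math
from collections import defaultdict

fine_to_coarse = {"fine.0": "A", "fine.1": "A", "fine.2": "B"}
coarse_to_fine = defaultdict(list)
for fine, coarse in fine_to_coarse.items():
    coarse_to_fine[coarse].append(fine)
# Pick the lexicographically smallest fine name as each clade's canonical node.
coarse_to_fine = {k: min(vs) for k, vs in coarse_to_fine.items()}
assert coarse_to_fine == {"A": "fine.0", "B": "fine.2"}
# Infinite weight ensures pruning can never remove a canonical node.
weights = {fine: 1.0 for fine in fine_to_coarse}
for fine in coarse_to_fine.values():
    weights[fine] = math.inf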
--------------------------------------------------------------------------------