├── docs ├── faq.md ├── tutorials │ └── overview.md ├── api │ ├── model │ │ ├── VIPRS.md │ │ ├── VIPRSMix.md │ │ ├── LDPredInf.md │ │ ├── BayesPRSModel.md │ │ └── gridsearch │ │ │ ├── VIPRSGrid.md │ │ │ ├── grid_utils.md │ │ │ ├── HyperparameterGrid.md │ │ │ └── HyperparameterSearch.md │ ├── plot │ │ └── diagnostics.md │ ├── utils │ │ ├── data_utils.md │ │ ├── exceptions.md │ │ ├── OptimizeResult.md │ │ └── compute_utils.md │ ├── eval │ │ ├── binary_metrics.md │ │ ├── pseudo_metrics.md │ │ └── continuous_metrics.md │ └── overview.md ├── citation.md ├── commandline │ ├── overview.md │ ├── viprs_score.md │ ├── viprs_evaluate.md │ └── viprs_fit.md ├── index.md ├── installation.md ├── download_ld.md └── getting_started.md ├── viprs ├── plot │ ├── __init__.py │ └── diagnostics.py ├── utils │ ├── __init__.py │ ├── exceptions.py │ ├── data_utils.py │ ├── math_utils.pxd │ ├── OptimizeResult.py │ ├── math_utils.pyx │ └── compute_utils.py ├── model │ ├── vi │ │ ├── __init__.py │ │ ├── e_step_cpp.pxd │ │ └── e_step_cpp.pyx │ ├── __init__.py │ ├── gridsearch │ │ ├── __init__.py │ │ ├── grid_utils.py │ │ ├── VIPRSGrid.py │ │ └── HyperparameterGrid.py │ └── LDPredInf.py ├── .DS_Store ├── eval │ ├── __init__.py │ ├── eval_utils.py │ ├── continuous_metrics.py │ ├── pseudo_metrics.py │ └── binary_metrics.py └── __init__.py ├── requirements-test.txt ├── requirements-optional.txt ├── .DS_Store ├── notebooks ├── .DS_Store └── height_example.ipynb ├── requirements-docs.txt ├── requirements.txt ├── MANIFEST.in ├── pyproject.toml ├── .gitignore ├── CITATION.md ├── .github └── workflows │ ├── ci-docs.yml │ ├── ci-linux.yml │ ├── ci-windows.yml │ ├── ci-osx.yml │ └── wheels.yml ├── Makefile ├── LICENSE ├── containers ├── cli.Dockerfile └── jupyter.Dockerfile ├── tests ├── conda_manual_testing.sh ├── test_cli.sh └── test_basic.py ├── README.md ├── mkdocs.yml ├── bin ├── viprs_score └── viprs_evaluate ├── CHANGELOG.md └── setup.py /docs/faq.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /viprs/plot/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /viprs/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/tutorials/overview.md: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements-test.txt: -------------------------------------------------------------------------------- 1 | pytest -------------------------------------------------------------------------------- /viprs/model/vi/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/api/model/VIPRS.md: -------------------------------------------------------------------------------- 1 | ::: viprs.model.VIPRS -------------------------------------------------------------------------------- /docs/api/model/VIPRSMix.md: -------------------------------------------------------------------------------- 1 | ::: viprs.model.VIPRSMix -------------------------------------------------------------------------------- /docs/api/model/LDPredInf.md: 
-------------------------------------------------------------------------------- 1 | ::: viprs.model.LDPredInf 2 | -------------------------------------------------------------------------------- /docs/api/plot/diagnostics.md: -------------------------------------------------------------------------------- 1 | ::: viprs.plot.diagnostics -------------------------------------------------------------------------------- /docs/api/utils/data_utils.md: -------------------------------------------------------------------------------- 1 | ::: viprs.utils.data_utils -------------------------------------------------------------------------------- /docs/api/utils/exceptions.md: -------------------------------------------------------------------------------- 1 | ::: viprs.utils.exceptions -------------------------------------------------------------------------------- /requirements-optional.txt: -------------------------------------------------------------------------------- 1 | scikit-optimize 2 | seaborn 3 | -------------------------------------------------------------------------------- /docs/api/eval/binary_metrics.md: -------------------------------------------------------------------------------- 1 | ::: viprs.eval.binary_metrics -------------------------------------------------------------------------------- /docs/api/model/BayesPRSModel.md: -------------------------------------------------------------------------------- 1 | ::: viprs.model.BayesPRSModel -------------------------------------------------------------------------------- /docs/api/utils/OptimizeResult.md: -------------------------------------------------------------------------------- 1 | ::: viprs.utils.OptimizeResult -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shz9/viprs/HEAD/.DS_Store -------------------------------------------------------------------------------- /docs/api/eval/pseudo_metrics.md: -------------------------------------------------------------------------------- 1 | ::: viprs.eval.pseudo_metrics 2 | -------------------------------------------------------------------------------- /docs/api/utils/compute_utils.md: -------------------------------------------------------------------------------- 1 | ::: viprs.utils.compute_utils 2 | -------------------------------------------------------------------------------- /docs/api/eval/continuous_metrics.md: -------------------------------------------------------------------------------- 1 | ::: viprs.eval.continuous_metrics 2 | -------------------------------------------------------------------------------- /docs/api/model/gridsearch/VIPRSGrid.md: -------------------------------------------------------------------------------- 1 | ::: viprs.model.gridsearch.VIPRSGrid -------------------------------------------------------------------------------- /docs/api/model/gridsearch/grid_utils.md: -------------------------------------------------------------------------------- 1 | ::: viprs.model.gridsearch.grid_utils -------------------------------------------------------------------------------- /viprs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shz9/viprs/HEAD/viprs/.DS_Store -------------------------------------------------------------------------------- /notebooks/.DS_Store: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/shz9/viprs/HEAD/notebooks/.DS_Store -------------------------------------------------------------------------------- /docs/api/model/gridsearch/HyperparameterGrid.md: -------------------------------------------------------------------------------- 1 | ::: viprs.model.gridsearch.HyperparameterGrid 2 | -------------------------------------------------------------------------------- /docs/api/model/gridsearch/HyperparameterSearch.md: -------------------------------------------------------------------------------- 1 | ::: viprs.model.gridsearch.HyperparameterSearch -------------------------------------------------------------------------------- /requirements-docs.txt: -------------------------------------------------------------------------------- 1 | mkdocs 2 | mkdocstrings-python 3 | mkdocs-material 4 | mkdocs-material-extensions -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy<2 2 | scipy 3 | pandas 4 | tqdm 5 | magenpy>=0.1.5 6 | statsmodels 7 | scikit-learn 8 | psutil 9 | joblib 10 | -------------------------------------------------------------------------------- /viprs/model/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .BayesPRSModel import BayesPRSModel 3 | from .VIPRS import VIPRS 4 | from .VIPRSMix import VIPRSMix 5 | from .gridsearch import VIPRSGrid 6 | 7 | -------------------------------------------------------------------------------- /viprs/utils/exceptions.py: -------------------------------------------------------------------------------- 1 | 2 | class OptimizationDivergence(Exception): 3 | """ 4 | Exception raised when the optimization algorithm diverges. 5 | """ 6 | pass 7 | -------------------------------------------------------------------------------- /viprs/model/gridsearch/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .HyperparameterGrid import HyperparameterGrid 3 | from .VIPRSGrid import VIPRSGrid 4 | 5 | from .grid_utils import ( 6 | select_best_model, 7 | bayesian_model_average, 8 | ) 9 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include MANIFEST.in 2 | include requirements*.txt 3 | include LICENSE 4 | include *.md 5 | include setup.py 6 | 7 | graft viprs 8 | 9 | global-exclude *.c 10 | global-exclude *.cpp 11 | global-exclude *.so 12 | global-exclude *.pyd 13 | global-exclude *.pyc 14 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | # Minimum requirements for the build system to execute. 
3 | requires = [ 4 | "setuptools", 5 | "wheel", 6 | "cython", 7 | "extension-helpers", 8 | "scipy", 9 | "oldest-supported-numpy", 10 | "pkgconfig" 11 | ] 12 | build-backend = "setuptools.build_meta" 13 | 14 | [tool.cibuildwheel] 15 | test-extras = "test" 16 | test-command = "pytest {project}/tests" 17 | # Optional 18 | build-verbosity = 1 19 | 20 | -------------------------------------------------------------------------------- /viprs/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | 2 | def download_ld_matrix(target_dir='.', chromosome=None): 3 | """ 4 | Download LD matrices for VIPRS software. 5 | 6 | TODO: Update this once data is made available. 7 | 8 | :param target_dir: The path or directory where to store the LD matrix 9 | :param chromosome: An integer or list of integers with the chromosome numbers for which to download 10 | the LD matrices from Zenodo. 11 | """ 12 | 13 | raise NotImplementedError("This function is not yet implemented.") 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__ 2 | *.py[cod] 3 | **/.ipynb_checkpoints 4 | 5 | # C extensions 6 | *.c 7 | *.cpp 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | env/ 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | site/ 23 | var/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | .eggs 28 | 29 | # PyInstaller 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # custom 38 | .idea/ 39 | .vscode/ 40 | .tox/ 41 | .pytest_cache/ 42 | *.html 43 | *.zarr 44 | *.npz 45 | *.log 46 | **/temp 47 | -------------------------------------------------------------------------------- /docs/citation.md: -------------------------------------------------------------------------------- 1 | If you use `viprs` in your research, please cite the following paper(s): 2 | 3 | ```bibtex 4 | @article{ZABAD2023741, 5 | title = {Fast and accurate Bayesian polygenic risk modeling with variational inference}, 6 | journal = {The American Journal of Human Genetics}, 7 | volume = {110}, 8 | number = {5}, 9 | pages = {741-761}, 10 | year = {2023}, 11 | issn = {0002-9297}, 12 | doi = {https://doi.org/10.1016/j.ajhg.2023.03.009}, 13 | url = {https://www.sciencedirect.com/science/article/pii/S0002929723000939}, 14 | author = {Shadi Zabad and Simon Gravel and Yue Li} 15 | } 16 | ``` 17 | -------------------------------------------------------------------------------- /viprs/utils/math_utils.pxd: -------------------------------------------------------------------------------- 1 | from cython cimport floating 2 | 3 | cdef floating[::1] softmax(floating[::1] x) noexcept nogil 4 | cdef floating sigmoid(floating x) noexcept nogil 5 | cdef floating logit(floating x) noexcept nogil 6 | cdef floating dot(floating[::1] v1, floating[::1] v2) noexcept nogil 7 | cdef floating vec_sum(floating[::1] v1) noexcept nogil 8 | cdef void axpy(floating[::1] v1, floating[::1] v2, floating s) noexcept nogil 9 | cdef void scipy_blas_axpy(floating[::1] v1, floating[::1] v2, floating alpha) noexcept nogil 10 | cdef floating scipy_blas_dot(floating[::1] v1, floating[::1] v2) noexcept nogil 11 | 12 | cdef floating[::1] clip_list(floating[::1] a, floating min_value, floating max_value) noexcept nogil 13 | cdef floating c_max(floating[::1] x) 
noexcept nogil 14 | cdef floating clip(floating a, floating min_value, floating max_value) noexcept nogil 15 | -------------------------------------------------------------------------------- /CITATION.md: -------------------------------------------------------------------------------- 1 | If you use `viprs` in your research, please cite the following paper(s): 2 | 3 | > Zabad, S., Gravel, S., & Li, Y. (2023). **Fast and accurate Bayesian polygenic risk modeling with variational inference.** 4 | The American Journal of Human Genetics, 110(5), 741–761. https://doi.org/10.1016/j.ajhg.2023.03.009 5 | 6 | ## BibTeX records 7 | 8 | ```bibtex 9 | @article{ZABAD2023741, 10 | title = {Fast and accurate Bayesian polygenic risk modeling with variational inference}, 11 | journal = {The American Journal of Human Genetics}, 12 | volume = {110}, 13 | number = {5}, 14 | pages = {741-761}, 15 | year = {2023}, 16 | issn = {0002-9297}, 17 | doi = {https://doi.org/10.1016/j.ajhg.2023.03.009}, 18 | url = {https://www.sciencedirect.com/science/article/pii/S0002929723000939}, 19 | author = {Shadi Zabad and Simon Gravel and Yue Li} 20 | } 21 | ``` 22 | -------------------------------------------------------------------------------- /.github/workflows/ci-docs.yml: -------------------------------------------------------------------------------- 1 | name: Build Docs 2 | on: 3 | push: 4 | branches: 5 | - master 6 | - main 7 | permissions: 8 | contents: write 9 | jobs: 10 | deploy: 11 | runs-on: ubuntu-latest 12 | steps: 13 | 14 | - uses: actions/checkout@v4 15 | 16 | - name: Configure Git Credentials 17 | run: | 18 | git config user.name github-actions[bot] 19 | git config user.email 41898282+github-actions[bot]@users.noreply.github.com 20 | 21 | - uses: actions/setup-python@v5 22 | with: 23 | python-version: 3.12 24 | 25 | - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV 26 | 27 | - uses: actions/cache@v4 28 | with: 29 | key: mkdocs-material-${{ env.cache_id }} 30 | path: .cache 31 | restore-keys: | 32 | mkdocs-material- 33 | 34 | - run: python -m pip install -v -e .[docs] 35 | - run: mkdocs gh-deploy --force 36 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: build build-inplace test test-inplace dist redist install install-from-source clean uninstall publish-test publish 2 | 3 | build: 4 | python3 setup.py build 5 | 6 | build-inplace: 7 | python3 setup.py build_ext --inplace 8 | 9 | test-inplace: 10 | PYTHONPATH=. pytest 11 | 12 | test: 13 | python -m pytest 14 | 15 | dist: 16 | python setup.py sdist bdist_wheel 17 | 18 | redist: clean dist 19 | 20 | install: 21 | python -m pip install . 22 | 23 | install-from-source: dist 24 | python -m pip install dist/viprs-*.tar.gz 25 | 26 | clean: 27 | $(RM) -r build dist *.egg-info 28 | $(RM) -r viprs/model/vi/*.c viprs/model/vi/*.cpp 29 | $(RM) -r viprs/utils/*.c viprs/utils/*.cpp 30 | $(RM) -r viprs/model/vi/*.so viprs/utils/*.so 31 | $(RM) -r .pytest_cache .tox temp output 32 | find . 
-name __pycache__ -exec rm -r {} + 33 | 34 | uninstall: 35 | python -m pip uninstall viprs 36 | 37 | publish-test: 38 | python -m twine upload -r testpypi dist/* --verbose 39 | 40 | publish: 41 | python -m twine upload dist/* --verbose 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Shadi Zabad, McGill University 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /viprs/eval/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .binary_metrics import * 3 | from .continuous_metrics import * 4 | 5 | # Define a dictionary that maps evaluation metric names to their respective functions: 6 | eval_metric_names = { 7 | 'Pearson_R': pearson_r, 8 | 'Spearman_R': spearman_r, 9 | 'MSE': mse, 10 | 'R2': r2, 11 | 'R2_residualized_target': r2_residualized_target, 12 | 'Incremental_R2': incremental_r2, 13 | 'Partial_Correlation': partial_correlation, 14 | 'AUROC': roc_auc, 15 | 'AUPRC': pr_auc, 16 | 'Avg_Precision': avg_precision, 17 | 'F1_Score': f1, 18 | 'Liability_R2': liability_r2, 19 | 'Liability_Probit_R2': liability_probit_r2, 20 | 'Liability_Logit_R2': liability_logit_r2, 21 | 'Nagelkerke_R2': nagelkerke_r2, 22 | 'CoxSnell_R2': cox_snell_r2, 23 | 'McFadden_R2': mcfadden_r2 24 | } 25 | 26 | # Define a list of metrics that can work with or require 27 | # covariates to be computed: 28 | eval_incremental_metrics = [ 29 | 'Incremental_R2', 30 | 'R2_residualized_target', 31 | 'Partial_Correlation', 32 | 'Liability_R2', 33 | 'Liability_Probit_R2', 34 | 'Liability_Logit_R2', 35 | 'Nagelkerke_R2', 36 | 'CoxSnell_R2', 37 | 'McFadden_R2' 38 | ] 39 | -------------------------------------------------------------------------------- /.github/workflows/ci-linux.yml: -------------------------------------------------------------------------------- 1 | name: Linux-CI 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | matrix: 11 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] 12 | 13 | steps: 14 | - name: Checkout code 15 | uses: actions/checkout@v4 16 | 17 | - name: Set up miniconda 18 | uses: conda-incubator/setup-miniconda@v3 19 | with: 20 | channels: conda-forge 21 | python-version: ${{ matrix.python-version 
}} 22 | 23 | - name: Set up Conda environment 24 | shell: "bash -l {0}" 25 | run: > 26 | conda create --name "viprs_ci" -c conda-forge -c anaconda 27 | python=${{matrix.python-version}} pip wheel compilers openblas libcblas -y 28 | 29 | - name: Show info about `viprs_ci` environment 30 | shell: "bash -l {0}" 31 | run: | 32 | conda list --show-channel-urls -n viprs_ci 33 | 34 | - name: Install viprs 35 | shell: "bash -l {0}" 36 | run: | 37 | conda activate viprs_ci 38 | python -m pip install -v -e .[test] 39 | 40 | - name: Run tests 41 | shell: "bash -l {0}" 42 | run: | 43 | conda activate viprs_ci 44 | pytest -v 45 | -------------------------------------------------------------------------------- /.github/workflows/ci-windows.yml: -------------------------------------------------------------------------------- 1 | name: Windows-CI 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: windows-latest 9 | strategy: 10 | matrix: 11 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] 12 | 13 | steps: 14 | - name: Checkout code 15 | uses: actions/checkout@v4 16 | 17 | - name: Set up miniconda 18 | uses: conda-incubator/setup-miniconda@v3 19 | with: 20 | channels: conda-forge 21 | python-version: ${{ matrix.python-version }} 22 | 23 | - name: Set up Conda environment 24 | shell: "bash -l {0}" 25 | run: > 26 | conda create --name "viprs_ci" -c conda-forge -c anaconda 27 | python=${{matrix.python-version}} pip wheel pkg-config compilers openblas -y 28 | 29 | - name: Show info about `viprs_ci` environment 30 | shell: "bash -l {0}" 31 | run: | 32 | conda list --show-channel-urls -n viprs_ci 33 | 34 | - name: Install viprs 35 | shell: "bash -l {0}" 36 | run: | 37 | conda activate viprs_ci 38 | python -m pip install -v -e .[test] 39 | 40 | - name: Run tests 41 | shell: "bash -l {0}" 42 | run: | 43 | conda activate viprs_ci 44 | pytest -v -------------------------------------------------------------------------------- /.github/workflows/ci-osx.yml: -------------------------------------------------------------------------------- 1 | name: OSX-CI 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ${{ matrix.os }} 9 | strategy: 10 | matrix: 11 | os: [macos-13, macos-14] 12 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] 13 | 14 | steps: 15 | - name: Checkout code 16 | uses: actions/checkout@v4 17 | 18 | - name: Set up miniconda 19 | uses: conda-incubator/setup-miniconda@v3 20 | with: 21 | channels: conda-forge 22 | python-version: ${{ matrix.python-version }} 23 | 24 | - name: Set up Conda environment 25 | shell: "bash -l {0}" 26 | run: > 27 | conda create --name "viprs_ci" -c conda-forge -c anaconda 28 | python=${{matrix.python-version}} pip wheel compilers openblas -y 29 | 30 | - name: Show info about `viprs_ci` environment 31 | shell: "bash -l {0}" 32 | run: | 33 | conda list --show-channel-urls -n viprs_ci 34 | 35 | - name: Install viprs 36 | shell: "bash -l {0}" 37 | run: | 38 | conda activate viprs_ci 39 | python -m pip install -v -e .[test] 40 | 41 | - name: Run tests 42 | shell: "bash -l {0}" 43 | run: | 44 | conda activate viprs_ci 45 | pytest -v -------------------------------------------------------------------------------- /viprs/plot/diagnostics.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import seaborn as sns 3 | 4 | 5 | def plot_history(prs_model, quantity=None): 6 | """ 7 | This function plots the optimization history for various model parameters and/or objectives. 
For 8 | every iteration step, we generally save quantities such as the ELBO, the heritability, etc. For the purposes 9 | of debugging and checking model convergence, it is useful to visually observe the trajectory 10 | of these quantities as a function of training iteration. 11 | 12 | :param prs_model: A `VIPRS` (or its derived classes) object. 13 | :param quantity: The quantities to plot (e.g. `ELBO`, `heritability`, etc.). 14 | 15 | :return: A seaborn `FacetGrid` object containing the plots. 16 | """ 17 | 18 | if quantity is None: 19 | quantity = prs_model.history.keys() 20 | elif isinstance(quantity, str): 21 | quantity = [quantity] 22 | 23 | q_dfs = [] 24 | 25 | for attr in quantity: 26 | 27 | df = pd.DataFrame({'Value': prs_model.history[attr]}) 28 | df.reset_index(inplace=True) 29 | df.columns = ['Step', 'Value'] 30 | df['Quantity'] = attr 31 | 32 | q_dfs.append(df) 33 | 34 | q_dfs = pd.concat(q_dfs) 35 | 36 | g = sns.relplot( 37 | data=q_dfs, x="Step", y="Value", 38 | row="Quantity", 39 | facet_kws={'sharey': False, 'sharex': True}, 40 | kind="scatter", 41 | marker="." 42 | ) 43 | 44 | return g 45 | -------------------------------------------------------------------------------- /containers/cli.Dockerfile: -------------------------------------------------------------------------------- 1 | # Usage: 2 | # ** Step 1 ** Build the docker image: 3 | # docker build -f cli.Dockerfile -t viprs-cli . 4 | # ** Step 2** Run the docker container in interactive shell mode: 5 | # docker run -it viprs-cli /bin/bash 6 | # ** Step 3** Test viprs fit: 7 | # viprs_fit -h 8 | 9 | FROM python:3.11-slim-buster 10 | 11 | LABEL authors="Shadi Zabad" 12 | LABEL version="0.1.3" 13 | LABEL description="Docker image containing all requirements to run the commandline scripts in the VIPRS package" 14 | 15 | # Install system dependencies 16 | RUN apt-get update && apt-get install -y \ 17 | unzip \ 18 | wget \ 19 | pkg-config \ 20 | g++ gcc \ 21 | libopenblas-dev \ 22 | libomp-dev 23 | 24 | # Download and setup plink2: 25 | RUN mkdir -p /software && \ 26 | wget https://s3.amazonaws.com/plink2-assets/alpha5/plink2_linux_avx2_20240105.zip -O /software/plink2.zip && \ 27 | unzip /software/plink2.zip -d /software && \ 28 | rm /software/plink2.zip 29 | 30 | # Download and setup plink1.9: 31 | RUN mkdir -p /software && \ 32 | wget https://s3.amazonaws.com/plink1-assets/plink_linux_x86_64_20231211.zip -O /software/plink.zip && \ 33 | unzip /software/plink.zip -d /software && \ 34 | rm /software/plink.zip 35 | 36 | # Add plink1.9 and plink2 to PATH: 37 | RUN echo 'export PATH=$PATH:/software' >> ~/.bashrc 38 | 39 | # Install viprs package from PyPI 40 | RUN pip install --upgrade pip viprs 41 | 42 | # Test the installation 43 | RUN viprs_fit -h 44 | RUN viprs_score -h 45 | RUN viprs_evaluate -h 46 | -------------------------------------------------------------------------------- /viprs/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | __version__ = '0.1.3' 3 | __release_date__ = 'April 2025' 4 | 5 | from .model import VIPRS 6 | from .utils.data_utils import * 7 | 8 | 9 | def make_ascii_logo(desc=None, left_padding=None): 10 | """ 11 | Generate an ASCII logo for the VIPRS package. 12 | :param desc: A string description to be added below the logo. 13 | :param left_padding: Padding to the left of the logo. 14 | 15 | :return: A string containing the ASCII logo. 
16 | """ 17 | 18 | logo = r""" 19 | _____ 20 | ___ _____(_)________ ________________ 21 | __ | / /__ / ___ __ \__ ___/__ ___/ 22 | __ |/ / _ / __ /_/ /_ / _(__ ) 23 | _____/ /_/ _ .___/ /_/ /____/ 24 | /_/ 25 | """ 26 | 27 | lines = logo.replace(' ', '\u2001').splitlines()[1:] 28 | lines.append("Variational Inference of Polygenic Risk Scores") 29 | lines.append(f"Version: {__version__} | Release date: {__release_date__}") 30 | lines.append("Author: Shadi Zabad, McGill University") 31 | 32 | # Find the maximum length of the lines 33 | max_len = max([len(l) for l in lines]) 34 | if desc is not None: 35 | max_len = max(max_len, len(desc)) 36 | 37 | # Pad the lines to the same length 38 | for i, l in enumerate(lines): 39 | lines[i] = l.center(max_len) 40 | 41 | # Add separators at the top and bottom 42 | lines.insert(0, '*' * max_len) 43 | lines.append('*' * max_len) 44 | 45 | if desc is not None: 46 | lines.append(desc.center(max_len)) 47 | 48 | if left_padding is not None: 49 | for i, l in enumerate(lines): 50 | lines[i] = '\u2001' * left_padding + l 51 | 52 | return "\n".join(lines) 53 | -------------------------------------------------------------------------------- /containers/jupyter.Dockerfile: -------------------------------------------------------------------------------- 1 | # Usage: 2 | # ** Step 1 ** Build the docker image: 3 | # docker build -f ../vemPRS/containers/jupyter.Dockerfile -t viprs-jupyter . 4 | # ** Step 2 ** Run the docker container (pass the appropriate port): 5 | # docker run -p 8888:8888 viprs-jupyter 6 | # ** Step 3 ** Open the link in your browser: 7 | # http://localhost:8888 8 | 9 | 10 | FROM python:3.11-slim-buster 11 | 12 | LABEL authors="Shadi Zabad" 13 | LABEL version="0.1" 14 | LABEL description="Docker image containing all requirements to run the VIPRS package in a Jupyter Notebook" 15 | 16 | # Install system dependencies 17 | RUN apt-get update && apt-get install -y \ 18 | unzip \ 19 | wget \ 20 | pkg-config \ 21 | g++ gcc \ 22 | libopenblas-dev \ 23 | libomp-dev 24 | 25 | # Download and setup plink2: 26 | RUN mkdir -p /software && \ 27 | wget https://s3.amazonaws.com/plink2-assets/alpha5/plink2_linux_avx2_20240105.zip -O /software/plink2.zip && \ 28 | unzip /software/plink2.zip -d /software && \ 29 | rm /software/plink2.zip 30 | 31 | # Download and setup plink1.9: 32 | RUN mkdir -p /software && \ 33 | wget https://s3.amazonaws.com/plink1-assets/plink_linux_x86_64_20231211.zip -O /software/plink.zip && \ 34 | unzip /software/plink.zip -d /software && \ 35 | rm /software/plink.zip 36 | 37 | # Add plink1.9 and plink2 to PATH: 38 | RUN echo 'export PATH=$PATH:/software' >> ~/.bashrc 39 | 40 | # Install viprs package from PyPI 41 | RUN pip install --upgrade pip viprs jupyterlab 42 | 43 | # Expose the port Jupyter Lab will be served on 44 | EXPOSE 8888 45 | 46 | # Set the working directory 47 | WORKDIR /viprs_dir 48 | 49 | # Copy the current directory contents into the container at /app 50 | COPY . /viprs_dir 51 | 52 | # Run Jupyter Lab 53 | CMD ["jupyter", "lab", "--ip=0.0.0.0", "--allow-root", "--NotebookApp.token=''"] 54 | -------------------------------------------------------------------------------- /docs/commandline/overview.md: -------------------------------------------------------------------------------- 1 | In addition to the python package interface, users may also opt to use `viprs` via commandline scripts. 
2 | The commandline interface is designed to be user-friendly and to provide a variety of options for the user to 3 | customize the inference process. 4 | 5 | When you install `viprs` using `pip`, the commandline scripts are automatically installed on your system and 6 | are available for use. The following scripts are meant to facilitate the entire pipeline of polygenic score inference, 7 | from fitting and estimating the posterior distribution of the variant effect sizes to predicting the PRS for a set of 8 | test individuals and evaluating the performance of the PRS predictions on held out samples. 9 | 10 | * [`viprs_fit`](viprs_fit.md): This script is used to fit the variational PRS model to the GWAS summary statistics and to estimate the 11 | posterior distribution of the variant effect sizes. The script provides a variety of options for the user to 12 | customize the inference process, including the choice of prior distributions and the choice of 13 | optimization algorithms. 14 | 15 | * [`viprs_score`](viprs_score.md): This script is used to predict the PRS for a set of individuals using the 16 | estimated variant effect sizes from the `viprs_fit` script. This is the script that generates the PRS per 17 | individual. 18 | 19 | * [`viprs_evaluate`](viprs_evaluate.md): This script is used to evaluate the performance of the PRS predictions 20 | using the PRS computed in the previous step. The script provides a variety of 21 | options for the user to customize the evaluation process, including the choice of performance metrics and 22 | the choice of evaluation datasets. 23 | 24 | 25 | ## TODO 26 | - Create a `nextflow` pipeline that runs all of the above steps in a single command. 27 | -------------------------------------------------------------------------------- /tests/conda_manual_testing.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # A script to test the package with different Python versions manually using conda 4 | # May be useful for sanity checks before pushing changes to the repository. 
5 | 6 | # Usage: 7 | # $ source tests/conda_manual_testing.sh 8 | 9 | # ============================================================================== 10 | 11 | if [[ -t 1 ]]; then 12 | set -e # Enable exit on error, only in non-interactive sessions 13 | fi 14 | 15 | # Activate the base conda environment 16 | source activate base 17 | 18 | SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" 19 | echo "Running tests from: $SCRIPT_DIR" 20 | 21 | # Define Python versions (add more here if needed) 22 | python_versions=("3.8" "3.9" "3.10" "3.11" "3.12") 23 | 24 | # ============================================================================== 25 | 26 | # Loop over Python versions 27 | for version in "${python_versions[@]}" 28 | do 29 | # Create a new conda environment for the Python version 30 | conda create --name "viprs_py$version" python="$version" -y || return 1 31 | 32 | # Activate the conda environment 33 | conda activate "viprs_py$version" 34 | 35 | # Add some of the required dependencies: 36 | conda install -c conda-forge -c anaconda pip wheel compilers openblas -y 37 | 38 | # Check python version: 39 | python --version 40 | 41 | # Install viprs 42 | make clean 43 | python -m pip install -e .[test] 44 | 45 | # List the installed packages: 46 | python -m pip list 47 | 48 | # Run pytest 49 | python -m pytest -v 50 | 51 | # Test the CLI scripts: 52 | bash "$SCRIPT_DIR/test_cli.sh" 53 | 54 | # Deactivate the conda environment 55 | conda deactivate 56 | 57 | # Remove the conda environment 58 | conda env remove --name "viprs_py$version" -y 59 | done 60 | -------------------------------------------------------------------------------- /docs/api/overview.md: -------------------------------------------------------------------------------- 1 | ## Models 2 | 3 | * [BayesPRSModel](model/BayesPRSModel.md): A base class for all Bayesian PRS models. 4 | * [VIPRS](model/VIPRS.md): Implementation of VIPRS with the "**spike-and-slab**" prior. 5 | * Implementation of VIPRS with **other priors**: 6 | * [VIPRSMix](model/VIPRSMix.md): VIPRS with a sparse Gaussian mixture prior. 7 | * **Hyperparameter Tuning**: Models/modules for performing hyperparameter search with `VIPRS` models. 8 | * [Hyperparameter grid](model/gridsearch/HyperparameterGrid.md): A utility class to help construct grids over model hyperparameters. 9 | * [HyperparameterSearch](model/gridsearch/HyperparameterSearch.md) 10 | * [VIPRSGrid](model/gridsearch/VIPRSGrid.md) 11 | * [grid_utils](model/gridsearch/grid_utils.md): Utilities for performing model selection/averaging. 12 | * **Baseline Models**: 13 | * [LDPredInf](model/LDPredInf.md): Implementation of the LDPred-inf model. 14 | 15 | ## Model Evaluation 16 | 17 | * [Binary metrics](eval/binary_metrics.md): Evaluation metrics for binary (case-control) phenotypes. 18 | * [Continuous metrics](eval/continuous_metrics.md): Evaluation metrics for continuous phenotypes. 19 | * [Pseudo metrics](eval/pseudo_metrics.md): Evaluation metrics based on GWAS summary statistics. 20 | 21 | ## Utilities 22 | 23 | * [Data utilities](utils/data_utils.md): Utilities for downloading and processing relevant data. 24 | * [Compute utilities](utils/compute_utils.md): Utilities for computing various statistics / quantities over python data structures. 25 | * [Exceptions](utils/exceptions.md): Custom exceptions used in the package. 26 | * [OptimizeResult](utils/OptimizeResult.md): A class to store the result of an optimization routine. 
27 | 28 | ## Plotting 29 | 30 | * [Diagnostic plots](plot/diagnostics.md): Functions for plotting various quantities / results from VIPRS or other PRS models. 31 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Variational Inference of Polygenic Risk Scores (VIPRS) 2 | 3 | This site contains documentation, tutorials, and examples for using the `viprs` package for the purposes of 4 | inferring polygenic risk scores (PRS) from GWAS summary statistics. The `viprs` package is a python package 5 | that uses variational inference to estimate the posterior distribution of variant effect sizes conditional 6 | on the GWAS summary statistics. The package is designed to be fast and accurate, and to provide a 7 | variety of options for the user to customize the inference process. 8 | 9 | The details of the method and algorithms are described in detail in the following paper(s): 10 | 11 | 12 | > Zabad, S., Gravel, S., & Li, Y. (2023). **Fast and accurate Bayesian polygenic risk modeling with variational inference.** 13 | The American Journal of Human Genetics, 110(5), 741–761. https://doi.org/10.1016/j.ajhg.2023.03.009 14 | 15 | 16 | ## Helpful links 17 | 18 | * [API Reference](api/overview.md) 19 | * [Installation](installation.md) 20 | * [Getting Started](getting_started.md) 21 | * [Command Line Scripts](commandline/overview.md) 22 | * [Download Reference LD matrices](download_ld.md) 23 | * [Project homepage on `GitHub`](https://github.com/shz9/viprs) 24 | * [Sister package `magenpy`](https://github.com/shz9/magenpy) 25 | 26 | ## Software contributions 27 | 28 | The latest version of the `viprs` package was developed in collaboration between research scientists 29 | at McGill University and Intel Labs. 
30 | 31 | * Contributors from **McGill University**: 32 | * [Shadi Zabad](https://github.com/shz9) 33 | * [Yue Li](https://www.cs.mcgill.ca/~yueli/) 34 | * [Simon Gravel](https://gravellab.github.io/) 35 | * Contributors from **Intel Labs**: 36 | * [Chirayu Anant Haryan](https://in.linkedin.com/in/chirayu-haryan) 37 | * [Sanchit Misra](https://sanchit-misra.github.io/) 38 | 39 | ## Contact 40 | 41 | If you have any questions or issues, please feel free to open an [issue](https://github.com/shz9/viprs/issues) 42 | on the `GitHub` repository or contact us directly at: 43 | 44 | * [Shadi Zabad](mailto:shadi.zabad@mail.mcgill.ca) 45 | * [Yue Li](mailto:yueli@cs.mcgill.ca) 46 | * [Simon Gravel](mailto:simon.gravel@mcgill.ca) 47 | 48 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # `viprs`: Variational Inference of Polygenic Risk Scores 2 | 3 | [![PyPI pyversions](https://img.shields.io/pypi/pyversions/viprs.svg)](https://pypi.python.org/pypi/viprs/) 4 | [![PyPI version fury.io](https://badge.fury.io/py/viprs.svg)](https://pypi.python.org/pypi/viprs/) 5 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 6 | 7 | [![Linux CI](https://github.com/shz9/viprs/actions/workflows/ci-linux.yml/badge.svg)](https://github.com/shz9/viprs/actions/workflows/ci-linux.yml) 8 | [![MacOS CI](https://github.com/shz9/viprs/actions/workflows/ci-osx.yml/badge.svg)](https://github.com/shz9/viprs/actions/workflows/ci-osx.yml) 9 | [![Windows CI](https://github.com/shz9/viprs/actions/workflows/ci-windows.yml/badge.svg)](https://github.com/shz9/viprs/actions/workflows/ci-windows.yml) 10 | [![Docs Build](https://github.com/shz9/viprs/actions/workflows/ci-docs.yml/badge.svg)](https://github.com/shz9/viprs/actions/workflows/ci-docs.yml) 11 | [![Binary wheels](https://github.com/shz9/viprs/actions/workflows/wheels.yml/badge.svg)](https://github.com/shz9/viprs/actions/workflows/wheels.yml) 12 | 13 | 14 | [![Downloads](https://static.pepy.tech/badge/viprs)](https://pepy.tech/project/viprs) 15 | [![Downloads](https://static.pepy.tech/badge/viprs/month)](https://pepy.tech/project/viprs) 16 | 17 | 18 | `viprs` is a python package that implements variational inference techniques to estimate the posterior distribution 19 | of variant effect sizes conditional on the GWAS summary statistics. The package is designed to be fast and accurate, 20 | and to provide a variety of options for the user to customize the inference process. 21 | Highlighted features: 22 | 23 | * The coordinate ascent algorithms are written in `C/C++` and `cython` for improved speed and efficiency. 24 | * The code is written in object-oriented form, allowing the user to extend and 25 | experiment with existing implementations. 26 | * Different priors on the effect size: Spike-and-slab, Sparse mixture, etc. 27 | * We also provide scripts for different hyperparameter tuning strategies, including: 28 | Grid search, Bayesian optimization, Bayesian model averaging. 29 | * Easy and straightforward interfaces for computing PRS from fitted models. 30 | * Implementation for a wide variety of evaluation metrics for both binary and continuous phenotypes. 
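As a rough end-to-end sketch of the bundled commandline workflow (file paths and names below are placeholders; see the [documentation](https://shz9.github.io/viprs/) for the full set of options):

```bash
# Fit VIPRS to GWAS summary statistics using a pre-computed LD reference:
viprs_fit -l "ld/chr_*/" -s "sumstats.txt" --sumstats-format "fastgwa" \
          --output-dir "model_fit/"

# Compute polygenic scores for test individuals from the fitted model:
viprs_score -f "model_fit/VIPRS_EM.fit.gz" --bfile "genotypes/chr_*.bed" \
            --output-file "scores/test" --compress

# Evaluate predictive performance against observed phenotypes:
viprs_evaluate --prs-file "scores/test.prs.gz" \
               --phenotype-file "phenotypes.txt" \
               --output-file "evaluation/metrics"
```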
31 | 32 | 33 | ### Helpful links 34 | 35 | - [Documentation](https://shz9.github.io/viprs/) 36 | - [Citation / BibTeX records](./CITATION.md) 37 | - [Report issues/bugs](https://github.com/shz9/viprs/issues) -------------------------------------------------------------------------------- /.github/workflows/wheels.yml: -------------------------------------------------------------------------------- 1 | name: Wheels 2 | 3 | on: 4 | workflow_dispatch: 5 | pull_request: 6 | push: 7 | branches: 8 | - main 9 | release: 10 | types: 11 | - published 12 | 13 | jobs: 14 | build_wheels: 15 | name: Build wheels on ${{ matrix.os }} 16 | runs-on: ${{ matrix.os }} 17 | strategy: 18 | matrix: 19 | # macos-13 is an intel runner, macos-14 is Apple Silicon 20 | # We'll skip building wheels for windows (windows-latest) for now. 21 | # There are issues to figure out regarding dependency management. 22 | os: [ubuntu-latest, macos-13, macos-14] 23 | env: 24 | CIBW_SKIP: "pp* cp36-* *-musllinux_* *win32 *_i686 *_s390x" 25 | #CIBW_BEFORE_BUILD_WINDOWS: "choco install pkgconfiglite" 26 | 27 | steps: 28 | - uses: actions/checkout@v4 29 | 30 | - name: Build wheels 31 | uses: pypa/cibuildwheel@v2.17.0 32 | 33 | - uses: actions/upload-artifact@v4 34 | with: 35 | name: cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }} 36 | path: ./wheelhouse/*.whl 37 | 38 | build_sdist: 39 | name: Build source distribution 40 | runs-on: ubuntu-latest 41 | steps: 42 | - uses: actions/checkout@v4 43 | 44 | - uses: actions/setup-python@v5 45 | name: Install Python 46 | with: 47 | python-version: '3.11' 48 | 49 | - name: Build sdist 50 | run: | 51 | python -m pip install --user pipx 52 | python -m pipx ensurepath 53 | python -m pipx run build --sdist 54 | 55 | - name: test install 56 | run: | 57 | python -m pip install --upgrade pip 58 | python -m pip install dist/viprs*.tar.gz 59 | viprs_fit -h 60 | 61 | - uses: actions/upload-artifact@v4 62 | with: 63 | name: cibw-sdist 64 | path: dist/*.tar.gz 65 | 66 | upload_pypi: 67 | needs: [build_wheels, build_sdist] 68 | runs-on: ubuntu-latest 69 | environment: pypi 70 | permissions: 71 | id-token: write 72 | if: github.event_name == 'release' && github.event.action == 'published' 73 | # or, alternatively, upload to PyPI on every tag starting with 'v' (remove on: release above to use this) 74 | # if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') 75 | steps: 76 | - uses: actions/download-artifact@v4 77 | with: 78 | # unpacks all CIBW artifacts into dist/ 79 | pattern: cibw-* 80 | path: dist 81 | merge-multiple: true 82 | 83 | - uses: pypa/gh-action-pypi-publish@release/v1 84 | with: 85 | user: __token__ 86 | password: ${{ secrets.PYPI_API_TOKEN }} 87 | # To test: repository-url: https://test.pypi.org/legacy/ -------------------------------------------------------------------------------- /docs/commandline/viprs_score.md: -------------------------------------------------------------------------------- 1 | Compute Polygenic Scores using inferred variant effect sizes (`viprs_score`) 2 | --- 3 | 4 | The `viprs_score` script is used to compute the polygenic risk scores (PRS) for a set of individuals 5 | using the estimated variant effect sizes from the `viprs_fit` script. This is the script that generates 6 | the PRS per individual. 
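For example, a typical invocation might look like the following (file paths here are placeholders; the full set of options is shown in the help message below):

```bash
viprs_score -f "output/viprs_fit/VIPRS_EM.fit.gz" \
            --bfile "data/test_chr_*.bed" \
            --output-file "output/scores/test_scores" \
            --compress
```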
7 | 8 | A full listing of the options available for the `viprs_score` script can be found by running the 9 | following command in your terminal: 10 | 11 | ```bash 12 | viprs_score -h 13 | ``` 14 | 15 | Which outputs the following help message: 16 | 17 | ```bash 18 |           ********************************************** 19 |                    _____                           20 |            ___   _____(_)________ ________________ 21 |            __ | / /__  / ___  __ \__  ___/__  ___/ 22 |            __ |/ / _  /  __  /_/ /_  /    _(__  )  23 |            _____/  /_/   _  .___/ /_/     /____/   24 |                          /_/                       25 |                 26 |           Variational Inference of Polygenic Risk Scores 27 |            Version: 0.1.3 | Release date: April 2025 28 |            Author: Shadi Zabad, McGill University 29 |           ********************************************** 30 |           < Compute Polygenic Scores for Test Samples > 31 | 32 | usage: viprs_score [-h] -f FIT_FILES --bfile BED_FILES --output-file OUTPUT_FILE [--temp-dir TEMP_DIR] [--keep KEEP] [--extract EXTRACT] 33 | [--backend {xarray,plink}] [--threads THREADS] [--compress] [--log-level {WARNING,CRITICAL,DEBUG,INFO,ERROR}] 34 | 35 | Commandline arguments for computing polygenic scores 36 | 37 | options: 38 | -h, --help show this help message and exit 39 | -f FIT_FILES, --fit-files FIT_FILES 40 | The path to the file(s) with the output parameter estimates from VIPRS. You may use a wildcard here if fit files are stored per- 41 | chromosome (e.g. "prs/chr_*.fit") 42 | --bfile BED_FILES The BED files containing the genotype data. You may use a wildcard here (e.g. "data/chr_*.bed") 43 | --output-file OUTPUT_FILE 44 | The output file where to store the polygenic scores (with no extension). 45 | --temp-dir TEMP_DIR The temporary directory where to store intermediate files. 46 | --keep KEEP A plink-style keep file to select a subset of individuals for the test set. 47 | --extract EXTRACT A plink-style extract file to select a subset of SNPs for scoring. 48 | --backend {xarray,plink} 49 | The backend software used for computations with the genotype matrix. 50 | --threads THREADS The number of threads to use for computations. 51 | --compress Compress the output file 52 | --log-level {WARNING,CRITICAL,DEBUG,INFO,ERROR} 53 | The logging level for the console output. 54 | 55 | ``` -------------------------------------------------------------------------------- /viprs/eval/eval_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def r2_stats(r2_val, n): 5 | """ 6 | Compute the confidence interval and p-value for a given R-squared (proportion of variance 7 | explained) value. 8 | 9 | This function and the formulas therein are based on the following paper 10 | by Momin et al. 2023: https://doi.org/10.1016/j.ajhg.2023.01.004 as well as 11 | the implementation in the R package `PRSmix`: 12 | https://github.com/buutrg/PRSmix/blob/main/R/get_PRS_acc.R#L63 13 | 14 | :param r2_val: The R^2 value to compute the confidence interval/p-value for. 15 | :param n: The sample size used to compute the R^2 value 16 | 17 | :return: A dictionary with the R^2 value, the lower and upper values of the confidence interval, 18 | the p-value, and the standard error of the R^2 metric. 19 | 20 | """ 21 | 22 | assert 0. < r2_val < 1., "R^2 value must be between 0 and 1." 23 | 24 | # Compute the variance of the R^2 value: 25 | r2_var = (4. * r2_val * (1. 
- r2_val) ** 2 * (n - 2) ** 2) / ((n ** 2 - 1) * (n + 3)) 26 | 27 | # Compute the standard errors for the R^2 value 28 | # as well as the lower and upper values for 29 | # the confidence interval: 30 | r2_se = np.sqrt(r2_var) 31 | lower_r2 = r2_val - 1.97 * r2_se 32 | upper_r2 = r2_val + 1.97 * r2_se 33 | 34 | from scipy import stats 35 | 36 | # Compute the p-value assuming a Chi-squared distribution with 1 degree of freedom: 37 | pval = stats.chi2.sf((r2_val / r2_se) ** 2, df=1) 38 | 39 | return { 40 | 'R2': r2_val, 41 | 'Lower_R2': lower_r2, 42 | 'Upper_R2': upper_r2, 43 | 'P_Value': pval, 44 | 'SE': r2_se, 45 | } 46 | 47 | 48 | def fit_linear_model(y, x, family='gaussian', link=None, add_intercept=False): 49 | """ 50 | Fit a linear model to the data `x` and `y` and return the model object. 51 | 52 | :param y: The independent variable (a numpy vector) 53 | :param x: The design matrix (a pandas DataFrame) 54 | :param family: The family of the model. Must be either 'gaussian' or 'binomial'. 55 | :param link: The link function to use for the model. If None, the default link function. 56 | :param add_intercept: If True, add an intercept term to the model. 57 | """ 58 | 59 | assert y.shape[0] == x.shape[0], ("The number of rows in the design matrix " 60 | "and the independent variable must match.") 61 | assert family in ('gaussian', 'binomial'), "The family must be either 'gaussian' or 'binomial'." 62 | if family == 'binomial': 63 | assert link in ('logit', 'probit', None), "The link function must be either 'logit', 'probit' or None." 64 | 65 | import statsmodels.api as sm 66 | 67 | if add_intercept: 68 | x = sm.add_constant(x) 69 | 70 | if family == 'gaussian': 71 | return sm.OLS(y, x).fit() 72 | elif family == 'binomial': 73 | if link == 'logit' or link is None: 74 | return sm.Logit(y, x).fit(disp=0) 75 | elif link == 'probit': 76 | return sm.Probit(y, x).fit(disp=0) 77 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Variational Inference of Polygenic Risk Scores (VIPRS) 2 | 3 | repo_name: viprs 4 | repo_url: https://github.com/shz9/viprs 5 | 6 | theme: 7 | name: "material" 8 | icon: 9 | repo: fontawesome/brands/github 10 | features: 11 | - announce.dismiss 12 | - content.action.edit 13 | - content.action.view 14 | - content.code.annotate 15 | - content.code.copy 16 | # - content.code.select 17 | # - content.footnote.tooltips 18 | # - content.tabs.link 19 | - content.tooltips 20 | # - header.autohide 21 | - navigation.expand 22 | - navigation.footer 23 | - navigation.indexes 24 | # - navigation.instant 25 | # - navigation.instant.prefetch 26 | # - navigation.instant.progress 27 | # - navigation.prune 28 | #- navigation.sections 29 | #- navigation.tabs 30 | # - navigation.tabs.sticky 31 | #- navigation.top 32 | - navigation.tracking 33 | - search.highlight 34 | - search.share 35 | - search.suggest 36 | - toc.follow 37 | # - toc.integrate 38 | palette: 39 | - media: "(prefers-color-scheme)" 40 | toggle: 41 | icon: material/link 42 | name: Switch to light mode 43 | - media: "(prefers-color-scheme: light)" 44 | scheme: default 45 | primary: indigo 46 | accent: indigo 47 | toggle: 48 | icon: material/toggle-switch 49 | name: Switch to dark mode 50 | - media: "(prefers-color-scheme: dark)" 51 | scheme: slate 52 | primary: black 53 | accent: indigo 54 | toggle: 55 | icon: material/toggle-switch-off 56 | name: Switch to system preference 57 | font: 58 | 
text: Roboto 59 | code: Roboto Mono 60 | 61 | plugins: 62 | - mkdocstrings: 63 | handlers: 64 | python: 65 | paths: [viprs] # search packages in the src folder 66 | options: 67 | docstring_style: sphinx 68 | - search 69 | - autorefs 70 | 71 | 72 | markdown_extensions: 73 | - admonition 74 | - pymdownx.highlight: 75 | anchor_linenums: true 76 | line_spans: __span 77 | pygments_lang_class: true 78 | - toc: 79 | permalink: true 80 | - pymdownx.inlinehilite 81 | - pymdownx.snippets 82 | - pymdownx.superfences 83 | - pymdownx.magiclink: 84 | normalize_issue_symbols: true 85 | repo_url_shorthand: true 86 | user: shz9 87 | repo: viprs 88 | 89 | nav: 90 | - "Home": index.md 91 | - "Installation": installation.md 92 | - "Getting Started": getting_started.md 93 | - "Download LD Reference": download_ld.md 94 | - "Tutorials": tutorials/overview.md 95 | - "Command Line Scripts": 96 | - "Overview": commandline/overview.md 97 | - "viprs_fit": commandline/viprs_fit.md 98 | - "viprs_score": commandline/viprs_score.md 99 | - "viprs_evaluate": commandline/viprs_evaluate.md 100 | - "Report issues/bugs": "https://github.com/shz9/viprs/issues" 101 | - "FAQ": faq.md 102 | - "Citation": citation.md 103 | - "API Reference": api/overview.md -------------------------------------------------------------------------------- /docs/installation.md: -------------------------------------------------------------------------------- 1 | The `viprs` software is written in `C/C++` and `Cython/Python3` and is designed to be fast and accurate. 2 | The software is designed to be used in a variety of computing environments, including local workstations, 3 | shared computing environments, and cloud-based computing environments. Because of the dependencies on `C/C++`, you need 4 | to ensure that a `C/C++` Compiler (with appropriate flags) is present on your system. 5 | 6 | ## Requirements 7 | 8 | Building the `viprs` package requires the following dependencies: 9 | 10 | * `python` (>=3.8) 11 | * `C/C++` Compilers 12 | * `cython` 13 | * `numPy` 14 | * `pkg-config` 15 | * `sciPy` (>=1.5.4) 16 | 17 | To take full advantage of the **parallel processing** capabilities of the package, you will also need to make sure that 18 | the following packages/libraries are available: 19 | 20 | * `OpenMP` 21 | * `BLAS` 22 | 23 | ### Setting up the environment with `conda` 24 | 25 | If you can use `Anaconda` or `miniconda` to manage your Python environment, we **recommend** using them to create 26 | a new environment with the required dependencies as follows: 27 | 28 | ```bash 29 | python_version=3.11 # Change python version here if needed 30 | conda create --name "viprs_env" -c anaconda -c conda-forge python=$python_version compilers pkg-config openblas -y 31 | conda activate viprs_env 32 | ``` 33 | 34 | Using `conda` to setup and manage your environment is especially *recommended* if you have trouble compiling 35 | the `C/C++` extensions on your system. 
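Once the environment is activated, it can be worth verifying that a compiler and `pkg-config` are visible before installing the package (a quick sanity check; the `compilers` package exports the `CC` environment variable on activation, and the exact compiler name varies by platform):

```bash
# Fall back to a generic compiler name if CC is not set by the environment:
${CC:-gcc} --version
pkg-config --version
```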
36 | 37 | ## Installation 38 | 39 | ### Using `pip` 40 | 41 | The package is available for easy installation via the Python Package Index (`pypi`) can 42 | be installed using `pip`: 43 | 44 | ```bash 45 | python -m pip install viprs>=0.1 46 | ``` 47 | 48 | ### Building from source 49 | 50 | You may also build the package from source, by cloning the repository and 51 | running the `make install` command: 52 | 53 | ```bash 54 | git clone https://github.com/shz9/viprs.git 55 | cd viprs 56 | make install 57 | ``` 58 | 59 | ### Using a virtual environment 60 | 61 | If you wish to use `viprs` on a shared computing environment or cluster, 62 | it is recommended that you install the package in a virtual environment. Here's a quick 63 | example of how to install `viprs` on a SLURM-based cluster: 64 | 65 | ```bash 66 | module load python/3.11 67 | python3 -m venv viprs_env # Assumes venv is available 68 | source viprs_env/bin/activate 69 | python -m pip install --upgrade pip 70 | python -m pip install viprs>=0.1 71 | ``` 72 | 73 | ### Using `Docker` containers 74 | 75 | If you are using `Docker` containers, you can build a container with the `viprs` package 76 | and all its dependencies by downloading the relevant `Dockerfile` from the 77 | [repository](https://github.com/shz9/viprs/tree/master/containers) and building it 78 | as follows: 79 | 80 | ```bash 81 | # Build the docker image: 82 | docker build -f cli.Dockerfile -t viprs-cli . 83 | # Run the container in interactive mode: 84 | docker run -it viprs-cli /bin/bash 85 | # Test that the package installed successfully: 86 | viprs_fit -h 87 | ``` 88 | 89 | We plan to publish pre-built `Docker` images on `DockerHub` in the future. 90 | -------------------------------------------------------------------------------- /viprs/model/vi/e_step_cpp.pxd: -------------------------------------------------------------------------------- 1 | from cython cimport floating 2 | cimport numpy as cnp 3 | 4 | # -------------------------------------------------- 5 | # Define fused data types: 6 | 7 | ctypedef fused indptr_type: 8 | cnp.int32_t 9 | cnp.int64_t 10 | 11 | ctypedef fused noncomplex_numeric: 12 | cnp.int8_t 13 | cnp.int16_t 14 | cnp.int32_t 15 | cnp.int64_t 16 | cnp.float32_t 17 | cnp.float64_t 18 | 19 | # -------------------------------------------------- 20 | 21 | cdef void cpp_blas_axpy(floating[::1] v1, floating[::1] v2, floating alpha) noexcept nogil 22 | cdef floating cpp_blas_dot(floating[::1] v1, floating[::1] v2) noexcept nogil 23 | 24 | 25 | cpdef void cpp_e_step(int[::1] ld_left_bound, 26 | indptr_type[::1] ld_indptr, 27 | noncomplex_numeric[::1] ld_data, 28 | floating[::1] std_beta, 29 | floating[::1] var_gamma, 30 | floating[::1] var_mu, 31 | floating[::1] eta, 32 | floating[::1] q, 33 | floating[::1] eta_diff, 34 | floating[::1] u_logs, 35 | floating[::1] half_var_tau, 36 | floating[::1] mu_mult, 37 | floating dq_scale, 38 | int threads, 39 | bint low_memory) noexcept nogil 40 | 41 | 42 | cpdef void cpp_e_step_mixture(int[::1] ld_left_bound, 43 | indptr_type[::1] ld_indptr, 44 | noncomplex_numeric[::1] ld_data, 45 | floating[::1] std_beta, 46 | floating[:, ::1] var_gamma, 47 | floating[:, ::1] var_mu, 48 | floating[::1] eta, 49 | floating[::1] q, 50 | floating[::1] eta_diff, 51 | floating[::1] log_null_pi, 52 | floating[:, ::1] u_logs, 53 | floating[:, ::1] half_var_tau, 54 | floating[:, ::1] mu_mult, 55 | floating dq_scale, 56 | int threads, 57 | bint low_memory) noexcept nogil 58 | 59 | cpdef void cpp_e_step_grid(int[::1] ld_left_bound, 60 
| indptr_type[::1] ld_indptr, 61 | noncomplex_numeric[::1] ld_data, 62 | floating[::1] std_beta, 63 | floating[::1, :] var_gamma, 64 | floating[::1, :] var_mu, 65 | floating[::1, :] eta, 66 | floating[::1, :] q, 67 | floating[::1, :] eta_diff, 68 | floating[::1, :] u_logs, 69 | floating[::1, :] half_var_tau, 70 | floating[::1, :] mu_mult, 71 | floating dq_scale, 72 | int[:] active_model_idx, 73 | int threads, 74 | bint low_memory) noexcept nogil 75 | -------------------------------------------------------------------------------- /tests/test_cli.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ -t 1 ]]; then 4 | set -e # Enable exit on error, only in non-interactive sessions 5 | fi 6 | 7 | BFILE_PATH=$(python3 -c "import magenpy as mgp; print(mgp.tgp_eur_data_path())") 8 | SUMSTATS_PATH=$(python3 -c "import magenpy as mgp; print(mgp.ukb_height_sumstats_path())") 9 | LD_BLOCKS_PATH="https://bitbucket.org/nygcresearch/ldetect-data/raw/ac125e47bf7ff3e90be31f278a7b6a61daaba0dc/EUR/fourier_ls-all.bed" 10 | 11 | # ------------------------------------------------------------------- 12 | # Use `magenpy_ld` cli script to estimate LD: 13 | 14 | echo "> Estimating LD using the block estimator:" 15 | magenpy_ld --estimator "block" \ 16 | --bfile "$BFILE_PATH" \ 17 | --ld-blocks "$LD_BLOCKS_PATH" \ 18 | --output-dir "output/ld_block/" 19 | 20 | # Check that there's a directory called "output/ld_block/chr_22/": 21 | if [ ! -d "output/ld_block/chr_22/" ]; then 22 | echo "Error: The output directory was not created." 23 | exit 1 24 | fi 25 | 26 | # Check that the directory contains both `.zgroup` and `.zatrs` files: 27 | if [ ! -f "output/ld_block/chr_22/.zgroup" ] || [ ! -f "output/ld_block/chr_22/.zattrs" ]; then 28 | echo "Error: The output directory does not contain the expected files." 29 | exit 1 30 | fi 31 | 32 | # ------------------------------------------------------------------- 33 | # Test the `viprs_fit` cli script: 34 | 35 | echo -e "\n> Testing the VIPRS_EM model... \n" 36 | 37 | viprs_fit -l "output/ld_block/chr_22/" \ 38 | -s "$SUMSTATS_PATH" \ 39 | --sumstats-format "fastgwa" \ 40 | --output-dir "output/viprs_fit/" \ 41 | --output-profiler-metrics 42 | 43 | # Check that the output file exists: 44 | if [ ! -f "output/viprs_fit/VIPRS_EM.fit.gz" ]; then 45 | echo "Error: The output file was not created." 46 | exit 1 47 | fi 48 | 49 | echo -e "\n> Testing the VIPRS_GS model... \n" 50 | 51 | viprs_fit -l "output/ld_block/chr_22/" \ 52 | -s "$SUMSTATS_PATH" \ 53 | --sumstats-format "fastgwa" \ 54 | --hyp-search "GS" \ 55 | --pi-steps 10 \ 56 | --output-dir "output/viprs_fit/" \ 57 | --output-profiler-metrics 58 | 59 | # Check that the output file exists: 60 | if [ ! -f "output/viprs_fit/VIPRS_GS.fit.gz" ]; then 61 | echo "Error: The output file was not created." 62 | exit 1 63 | fi 64 | 65 | # ------------------------------------------------------------------- 66 | # Test the `viprs_score` cli script: 67 | 68 | echo -e "\n> Testing the VIPRS scoring... \n" 69 | viprs_score -f "output/viprs_fit/VIPRS_EM.fit.gz" \ 70 | --bfile "$BFILE_PATH" \ 71 | --output-file "output/viprs_score/scores" \ 72 | --compress 73 | 74 | # Check that the output file exists: 75 | if [ ! -f "output/viprs_score/scores.prs.gz" ]; then 76 | echo "Error: The output file was not created." 
77 | exit 1 78 | fi 79 | 80 | # ------------------------------------------------------------------- 81 | # Test the `viprs_evaluate` cli script: 82 | 83 | # TODO: Expand this to include proper testing of evaluation pipeline 84 | viprs_evaluate -h 85 | 86 | # ------------------------------------------------------------------- 87 | # Clean up after computation: 88 | rm -rf output/ 89 | rm -rf temp/ 90 | -------------------------------------------------------------------------------- /docs/commandline/viprs_evaluate.md: -------------------------------------------------------------------------------- 1 | Evaluate Predictive Performance of PRS (`viprs_evaluate`) 2 | --- 3 | 4 | The `viprs_evaluate` script is used to evaluate the performance of the PRS predictions using the PRS computed in 5 | the previous step. The script provides a variety of options for the user to customize the evaluation process, 6 | including the choice of performance metrics and the choice of evaluation datasets. 7 | 8 | A full listing of the options available for the `viprs_evaluate` script can be found by running the 9 | following command in your terminal: 10 | 11 | ```bash 12 | viprs_evaluate -h 13 | ``` 14 | 15 | Which outputs the following help message: 16 | 17 | ```bash 18 | 19 |           ********************************************** 20 |                    _____                           21 |            ___   _____(_)________ ________________ 22 |            __ | / /__  / ___  __ \__  ___/__  ___/ 23 |            __ |/ / _  /  __  /_/ /_  /    _(__  )  24 |            _____/  /_/   _  .___/ /_/     /____/   25 |                          /_/                       26 |                 27 |           Variational Inference of Polygenic Risk Scores 28 |            Version: 0.1.3 | Release date: April 2025 29 |            Author: Shadi Zabad, McGill University 30 |           ********************************************** 31 |           < Evaluate Prediction Accuracy of PRS Models > 32 | 33 | usage: viprs_evaluate [-h] --prs-file PRS_FILE --phenotype-file PHENO_FILE [--phenotype-col PHENO_COL] 34 | [--phenotype-likelihood {binomial,gaussian,infer}] [--keep KEEP] --output-file OUTPUT_FILE 35 | [--metrics METRICS [METRICS ...]] [--covariates-file COVARIATES_FILE] 36 | [--log-level {CRITICAL,WARNING,INFO,DEBUG,ERROR}] 37 | 38 | Commandline arguments for evaluating polygenic scores 39 | 40 | optional arguments: 41 | -h, --help show this help message and exit 42 | --prs-file PRS_FILE The path to the PRS file (expected format: FID IID PRS, tab-separated) 43 | --phenotype-file PHENO_FILE 44 | The path to the phenotype file. The expected format is: FID IID phenotype (no header), tab-separated. 45 | --phenotype-col PHENO_COL 46 | The column index for the phenotype in the phenotype file (0-based index). 47 | --phenotype-likelihood {binomial,gaussian,infer} 48 | The phenotype likelihood ("gaussian" for continuous, "binomial" for case-control). If not set, will be inferred 49 | automatically based on the phenotype file. 50 | --keep KEEP A plink-style keep file to select a subset of individuals for the evaluation. 51 | --output-file OUTPUT_FILE 52 | The output file where to store the evaluation metrics (with no extension). 53 | --metrics METRICS [METRICS ...] 54 | The evaluation metrics to compute (default: all available metrics that are relevant for the phenotype). For a full 55 | list of supported metrics, check the documentation. 
56 | --covariates-file COVARIATES_FILE 57 | A file with covariates for the samples included in the analysis. This tab-separated file should not have a header 58 | and the first two columns should be the FID and IID of the samples. 59 | --log-level {CRITICAL,WARNING,INFO,DEBUG,ERROR} 60 | The logging level for the console output. 61 | 62 | ``` -------------------------------------------------------------------------------- /docs/download_ld.md: -------------------------------------------------------------------------------- 1 | Linkage-Disequilibrium (LD) matrices, which record pairwise correlations between 2 | genetic variants, are required as input to the `VIPRS` model. To facilitate running the model 3 | on GWAS data from diverse ancestries, we computed LD matrices for 6 continental populations represented in 4 | the UK Biobank. The six ancestry groups and their corresponding download links are listed below: 5 | 6 | | Code | Ancestry group | Sample size | Download | 7 | |:-----:|:--------------------|:-----------:|:--------------------------------------------------------------------------------------------------------------------------------------------------------:| 8 | | `EUR` | European | 362446 | [GitHub](https://github.com/shz9/viprs/releases/download/v0.1.2/EUR.tar.gz) or [Zenodo](https://zenodo.org/records/14614207/files/EUR.tar.gz?download=1) | 9 | | `CSA` | Central/South Asian | 8284 | [GitHub](https://github.com/shz9/viprs/releases/download/v0.1.2/CSA.tar.gz) or [Zenodo](https://zenodo.org/records/14614207/files/CSA.tar.gz?download=1) | 10 | | `AFR` | African | 6255 | [GitHub](https://github.com/shz9/viprs/releases/download/v0.1.2/AFR.tar.gz) or [Zenodo](https://zenodo.org/records/14614207/files/AFR.tar.gz?download=1) | 11 | | `EAS` | East Asian | 2700 | [GitHub](https://github.com/shz9/viprs/releases/download/v0.1.2/EAS.tar.gz) or [Zenodo](https://zenodo.org/records/14614207/files/EAS.tar.gz?download=1) | 12 | | `MID` | Middle Eastern | 1567 | [GitHub](https://github.com/shz9/viprs/releases/download/v0.1.2/MID.tar.gz) or [Zenodo](https://zenodo.org/records/14614207/files/MID.tar.gz?download=1) | 13 | | `AMR` | Admixed American | 987 | [GitHub](https://github.com/shz9/viprs/releases/download/v0.1.2/AMR.tar.gz) or [Zenodo](https://zenodo.org/records/14614207/files/AMR.tar.gz?download=1) | 14 | 15 | 16 | The sample sizes here are restricted to unrelated individuals in the UK Biobank. 17 | 18 | The matrices were computed using the `block` LD estimator, where we only record pairwise correlations between 19 | variants in the same LD block. The LD blocks are defined by [`LDetect`](https://bitbucket.org/nygcresearch/ldetect-data/src/master/). 20 | The matrices were computed using the sister package [`magenpy`](https://shz9.github.io/magenpy/) and were then 21 | quantized to `int8` data type for enhanced compressibility. 22 | 23 | For European samples, we also provide LD matrices that record pairwise correlations for up to 18 million variants. 24 | This matrix is available for download via [Zenodo](https://zenodo.org/records/14614207). 25 | 26 | For more details on QC criteria, data preparation, etc., please consult our manuscript: 27 | 28 | >Zabad et al. (2025). Towards whole-genome inference of polygenic scores with fast and memory-efficient algorithms. 29 | > BioRxiv. 
30 | 31 | 32 | To access and use these matrices for downstream tasks, consult the codebase of [`magenpy`](https://shz9.github.io/magenpy/), our 33 | sister python package that implements specialized data structures for computing and processing large-scale LD matrices. 34 | 35 | ## Bash Script for downloading/extracting LD matrices 36 | 37 | Here is a bash script that can be used to download and extract the LD matrices for all 6 populations. The script uses 38 | the `GitHub` links provided above. Feel free to modify the script to suit your needs. 39 | 40 | ```bash 41 | #!/bin/bash 42 | output_dir="LD_matrices" 43 | populations=("EUR" "CSA" "AFR" "EAS" "MID" "AMR") 44 | extract=true 45 | 46 | mkdir -p $output_dir 47 | 48 | for pop in "${populations[@]}" 49 | do 50 | echo "Downloading LD matrix for $pop" 51 | wget -O $output_dir/$pop.tar.gz "https://github.com/shz9/viprs/releases/download/v0.1.2/$pop.tar.gz" 52 | if [ "$extract" = true ]; then 53 | mkdir -p $output_dir/$pop 54 | tar -xf $output_dir/$pop.tar.gz -C $output_dir/$pop 55 | fi 56 | done 57 | ``` 58 | 59 | -------------------------------------------------------------------------------- /viprs/model/LDPredInf.py: -------------------------------------------------------------------------------- 1 | from .BayesPRSModel import BayesPRSModel 2 | 3 | 4 | class LDPredInf(BayesPRSModel): 5 | """ 6 | A wrapper class implementing the LDPred-inf model. 7 | The LDPred-inf model is a Bayesian model that uses summary statistics 8 | from GWAS to estimate the posterior mean effect sizes of the SNPs. It is equivalent 9 | to performing ridge regression, with the penalty proportional to the inverse of 10 | the per-SNP heritability. 11 | 12 | Refer to the following references for details about the LDPred-inf model: 13 | * Vilhjálmsson et al. AJHG. 2015 14 | * Privé et al. Bioinformatics. 2020 15 | 16 | :ivar gdl: An instance of `GWADataLoader` 17 | :ivar h2: The heritability for the trait (can also be chromosome-specific) 18 | 19 | """ 20 | 21 | def __init__(self, 22 | gdl, 23 | h2=None): 24 | """ 25 | Initialize the LDPred-inf model. 26 | :param gdl: An instance of GWADataLoader 27 | :param h2: The heritability for the trait (can also be chromosome-specific) 28 | """ 29 | super().__init__(gdl) 30 | 31 | if h2 is None: 32 | from magenpy.stats.h2.ldsc import simple_ldsc 33 | self.h2 = simple_ldsc(self.gdl) 34 | else: 35 | self.h2 = h2 36 | 37 | def get_heritability(self): 38 | """ 39 | :return: The heritability estimate for the trait of interest. 40 | """ 41 | return self.h2 42 | 43 | def fit(self, solver='minres', **solver_kwargs): 44 | """ 45 | Fit the summary statistics-based ridge regression, 46 | following the specifications of the LDPred-inf model. 47 | 48 | !!! warning 49 | Not tested yet. 50 | 51 | Here, we use `lsqr` or `minres` solvers to solve the system of equations: 52 | 53 | (D + lam*I)BETA = BETA_HAT 54 | 55 | where D is the LD matrix, BETA is ridge regression 56 | estimate that we wish to obtain and BETA_HAT is the 57 | marginal effect sizes estimated from GWAS. 58 | 59 | In this case, lam = M / N*h2, where M is the number of SNPs, 60 | N is the number of samples and h2 is the heritability 61 | of the trait. 62 | 63 | https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.lsqr.html 64 | https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.minres.html 65 | 66 | :param solver: The solver for the system of linear equations. 
Options: `minres` or `lsqr` 67 | :param solver_kwargs: keyword arguments for the solver. 68 | """ 69 | 70 | assert solver in ('lsqr', 'minres') 71 | 72 | import numpy as np 73 | from scipy.sparse.linalg import lsqr, minres 74 | from scipy.sparse import identity, block_diag 75 | 76 | if solver == 'lsqr': 77 | solve = lsqr 78 | else: 79 | solve = minres 80 | 81 | # Lambda, the regularization parameter for the 82 | # ridge regression estimator. For LDPred-inf model, 83 | # we set this to 'M / N*h2', where M is the number of SNPs, 84 | # N is the number of samples and h2 is the heritability 85 | # of the trait. 86 | lam = self.n_snps / (self.n * self.h2) 87 | 88 | chroms = self.gdl.chromosomes 89 | 90 | # Extract the LD matrices for all the chromosomes represented and 91 | # concatenate them into one block diagonal matrix: 92 | ld_mats = [] 93 | for c in chroms: 94 | self.gdl.ld[c].load(dtype=np.float32) 95 | ld_mats.append(self.gdl.ld[c].csr_matrix) 96 | 97 | ld = block_diag(ld_mats, format='csr') 98 | 99 | # Extract the marginal GWAS effect sizes: 100 | marginal_beta = np.concatenate([self.gdl.sumstats_table[c].marginal_beta 101 | for c in chroms]) 102 | 103 | # Estimate the BETAs under the ridge penalty: 104 | res = solve(ld + lam * identity(ld.shape[0]), marginal_beta, **solver_kwargs) 105 | 106 | # Extract the estimates and populate them in `post_mean_beta` 107 | start = 0 108 | self.post_mean_beta = {} 109 | 110 | for c in chroms: 111 | self.post_mean_beta[c] = res[0][start:start + self.shapes[c]] 112 | start += self.shapes[c] 113 | 114 | return self 115 | -------------------------------------------------------------------------------- /viprs/eval/continuous_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from .eval_utils import fit_linear_model 4 | 5 | 6 | def r2(true_val, pred_val): 7 | """ 8 | Compute the R^2 (proportion of variance explained) between 9 | the predictions or PRS `pred_val` and the phenotype `true_val` 10 | 11 | :param true_val: The response value or phenotype (a numpy vector) 12 | :param pred_val: The predicted value or PRS (a numpy vector) 13 | 14 | :return: The R^2 value 15 | """ 16 | from scipy import stats 17 | 18 | _, _, r_val, _, _ = stats.linregress(pred_val, true_val) 19 | return r_val ** 2 20 | 21 | 22 | def mse(true_val, pred_val): 23 | """ 24 | Compute the mean squared error (MSE) between 25 | the predictions or PRS `pred_val` and the phenotype `true_val` 26 | 27 | :param true_val: The response value or phenotype (a numpy vector) 28 | :param pred_val: The predicted value or PRS (a numpy vector) 29 | 30 | :return: The mean squared error 31 | """ 32 | 33 | return np.mean((pred_val - true_val)**2) 34 | 35 | 36 | def spearman_r(true_val, pred_val): 37 | """ 38 | Compute the spearman correlation between the predictions or PRS `pred_val` and the phenotype `true_val` 39 | 40 | :param true_val: The response value or phenotype (a numpy vector) 41 | :param pred_val: The predicted value or PRS (a numpy vector) 42 | :return: The spearman correlation 43 | """ 44 | 45 | from scipy import stats 46 | return stats.spearmanr(true_val, pred_val).statistic 47 | 48 | 49 | def pearson_r(true_val, pred_val): 50 | """ 51 | Compute the pearson correlation coefficient between 52 | the predictions or PRS `pred_val` and the phenotype `true_val` 53 | 54 | :param true_val: The response value or phenotype (a numpy vector) 55 | :param pred_val: The predicted value or PRS (a numpy vector) 56 | 57 
| :return: The pearson correlation coefficient
58 |     """
59 |     return np.corrcoef(true_val, pred_val)[0, 1]
60 | 
61 | 
62 | def r2_residualized_target(true_val, pred_val, covariates):
63 |     """
64 |     Compute the R^2 (proportion of variance explained) between
65 |     the predictions or PRS `pred_val` and the phenotype `true_val`
66 |     after residualizing the phenotype on a set of covariates.
67 | 
68 |     :param true_val: The response value or phenotype (a numpy vector)
69 |     :param pred_val: The predicted value or PRS (a numpy vector)
70 |     :param covariates: A pandas table of covariates where the rows are ordered
71 |     the same way as the predictions and response.
72 | 
73 |     :return: The residualized R^2 value
74 |     """
75 | 
76 |     resid_true_val = fit_linear_model(true_val, covariates, add_intercept=True)
77 | 
78 |     return r2(resid_true_val.resid, pred_val)
79 | 
80 | 
81 | def incremental_r2(true_val, pred_val, covariates=None, return_all_r2=False):
82 |     """
83 |     Compute the incremental prediction R^2 (proportion of phenotypic variance explained by the PRS).
84 |     This metric is computed by taking the R^2 of a model with covariates+PRS and subtracting from it
85 |     the R^2 of a model with the covariates alone.
86 | 
87 |     :param true_val: The response value or phenotype (a numpy vector)
88 |     :param pred_val: The predicted value or PRS (a numpy vector)
89 |     :param covariates: A pandas table of covariates where the rows are ordered
90 |     the same way as the predictions and response.
91 |     :param return_all_r2: If True, return the R^2 values for the null and full models as well.
92 | 
93 |     :return: The incremental R^2 value
94 |     """
95 | 
96 |     if covariates is None:
97 |         add_intercept = False
98 |         covariates = pd.DataFrame(np.ones((true_val.shape[0], 1)), columns=['const'])
99 |     else:
100 |         add_intercept = True
101 | 
102 |     null_result = fit_linear_model(true_val, covariates, add_intercept=add_intercept)
103 |     full_result = fit_linear_model(true_val, covariates.assign(pred_val=pred_val),
104 |                                    add_intercept=add_intercept)
105 | 
106 |     if return_all_r2:
107 |         return {
108 |             'Null_R2': null_result.rsquared,
109 |             'Full_R2': full_result.rsquared,
110 |             'Incremental_R2': full_result.rsquared - null_result.rsquared
111 |         }
112 |     else:
113 |         return full_result.rsquared - null_result.rsquared
114 | 
115 | 
116 | def partial_correlation(true_val, pred_val, covariates):
117 |     """
118 |     Compute the partial correlation between the phenotype `true_val` and the PRS `pred_val`
119 |     by conditioning on a set of covariates. This metric is computed by first residualizing the
120 |     phenotype and the PRS on a set of covariates and then computing the correlation coefficient
121 |     between the residuals.
122 | 
123 |     :param true_val: The response value or phenotype (a numpy vector)
124 |     :param pred_val: The predicted value or PRS (a numpy vector)
125 |     :param covariates: A pandas table of covariates where the rows are ordered
126 |     the same way as the predictions and response.
127 | 128 | :return: The partial correlation coefficient 129 | """ 130 | 131 | true_response = fit_linear_model(true_val, covariates, add_intercept=True) 132 | pred_response = fit_linear_model(pred_val, covariates, add_intercept=True) 133 | 134 | return np.corrcoef(true_response.resid, pred_response.resid)[0, 1] 135 | -------------------------------------------------------------------------------- /docs/getting_started.md: -------------------------------------------------------------------------------- 1 | `viprs` is a `python` package for fitting Bayesian Polygenic Risk Score (PRS) models to summary statistics 2 | derived from Genome-wide Association Studies (GWASs). To showcase the interfaces and functionalities of the package 3 | as well as the data structures that power it, we will start with a simple example. 4 | 5 | !!! note 6 | This example is designed to highlight the features of the package and the python API. If you'd like to 7 | use the commandline interface, please refer to the [Command Line Scripts](commandline/overview.md) documentation. 8 | 9 | Generally, summary statistics-based PRS methods require access to: 10 | 11 | * GWAS summary statistics for the trait of interest 12 | * Linkage-Disequilibrium (LD) matrices from an appropriately-matched reference panel (e.g. 13 | from the 1KG dataset or UK Biobank). 14 | 15 | For the first item, we will use summary statistics for Standing Height (`EFO_0004339`) from the `fastGWA` 16 | [catalogue](https://yanglab.westlake.edu.cn/data/ukb_fastgwa/imp/pheno/50). 17 | For the second item, we will use genotype data on chromosome 22 for a subset of 378 European samples from the 18 | 1KG project. This small dataset is shipped with the python package `magenpy`. 19 | 20 | To start, let's import the required `python` packages: 21 | 22 | ```python linenums="1" 23 | import magenpy as mgp 24 | import viprs as vp 25 | ``` 26 | 27 | Then, we will use `magenpy` to read the 1KG genotype dataset and *automatically* match it with the GWAS 28 | summary statistics from `fastGWA`: 29 | 30 | ```python linenums="1" 31 | # Load genotype and GWAS summary statistics data (chromosome 22): 32 | gdl = mgp.GWADataLoader(bed_files=mgp.tgp_eur_data_path(), # Path of the genotype data 33 | sumstats_files=mgp.ukb_height_sumstats_path(), # Path of the summary statistics 34 | sumstats_format="fastGWA") # Specify the format of the summary statistics 35 | ``` 36 | 37 | Once the genotype and summary statistics data are read by `magenpy`, we can go ahead and compute 38 | the LD (or SNP-by-SNP correlation) matrix: 39 | 40 | ```python linenums="1" 41 | # Compute LD using the shrinkage estimator (Wen and Stephens 2010): 42 | gdl.compute_ld("shrinkage", 43 | output_dir="temp", 44 | genetic_map_ne=11400, # effective population size (Ne) 45 | genetic_map_sample_size=183, 46 | threshold=1e-3) 47 | ``` 48 | 49 | Because of the small sample size of the reference panel, here we recommend using the `shrinkage` estimator 50 | for LD from Wen and Stephens (2010). The shrinkage estimator results in compact and sparse LD matrices that are 51 | more robust than the sample LD. The estimator requires access to information about the genetic map, such as 52 | the position of each SNP in centi Morgan, the effective population size, and the sample size used to 53 | estimate the genetic map. 
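
If you'd like to double-check the LD computation before fitting the model, you can inspect the LD matrices
stored on the `GWADataLoader` object. The snippet below is a minimal sketch; it only relies on the `gdl.ld`
dictionary (keyed by chromosome) and the `load`, `csr_matrix` and `release` interfaces that are used elsewhere
in this package. Consult the `magenpy` documentation for the full API:

```python linenums="1"
import numpy as np

# `gdl.ld` maps each chromosome to an LD matrix object (here, chromosome 22 only):
print(gdl.chromosomes)

ld_mat = gdl.ld[22]
ld_mat.load(dtype=np.float32)    # Load the LD data into memory
print(ld_mat.csr_matrix.shape)   # Square matrix of pairwise correlations: (n_snps, n_snps)
ld_mat.release()                 # Release the in-memory LD data when done
```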
54 | 55 | Given the LD information from the reference panel, we can next fit the VIPRS model to the summary statistics data: 56 | 57 | ```python linenums="1" 58 | # Initialize VIPRS, passing it the GWADataLoader object 59 | v = vp.VIPRS(gdl) 60 | # Invoke the .fit() method to obtain posterior estimates 61 | v.fit() 62 | ``` 63 | 64 | Once the model converges, we can generate PRS estimates for height for the 1KG samples by simply 65 | invoking the `.predict()` method: 66 | 67 | ```python linenums="1" 68 | v.predict() 69 | ``` 70 | 71 | ``` 72 | array([ 0.01944202, 0.00597704, 0.07329462, ..., 0.06666187, 0.05251297, 0.00359018]) 73 | ``` 74 | These are the polygenic scores for height for the European samples in the 1KG dataset! 75 | 76 | To examine posterior estimates for the model parameters, you can simply invoke the `.to_table()` method: 77 | 78 | ```python linenums="1" 79 | v.to_table() 80 | ``` 81 | 82 | ``` 83 | CHR SNP A1 A2 PIP BETA VAR_BETA 84 | 0 22 rs131538 A G 0.006107 -5.955517e-06 1.874619e-08 85 | 1 22 rs9605903 C T 0.005927 5.527188e-06 1.774252e-08 86 | 2 22 rs5746647 G T 0.005015 1.194178e-07 1.120063e-08 87 | 3 22 rs16980739 T C 0.008331 -1.335695e-05 3.717944e-08 88 | 4 22 rs9605923 A T 0.006181 6.334971e-06 1.979157e-08 89 | ... ... ... .. .. ... ... ... 90 | 15930 22 rs8137951 A G 0.006367 -6.880591e-06 2.059650e-08 91 | 15931 22 rs2301584 A G 0.179406 -7.234545e-04 2.597197e-06 92 | 15932 22 rs3810648 G A 0.008000 1.222151e-05 3.399927e-08 93 | 15933 22 rs2285395 A G 0.005356 3.004282e-06 1.349082e-08 94 | 15934 22 rs28729663 A G 0.005350 -2.781053e-06 1.351239e-08 95 | 96 | [15935 rows x 7 columns] 97 | ``` 98 | 99 | Here, `PIP` is the **P**osterior **I**nclusion **P**robability under the variational density, while 100 | `BETA` and `VAR_BETA` are the posterior mean and variance for the effect size, respectively. 101 | For the purposes of prediction, we only need the `BETA` column. You can also examine the 102 | inferred hyperparameters of the model by invoking the `.to_theta_table()` method: 103 | 104 | ```python linenums="1" 105 | v.to_theta_table() 106 | ``` 107 | 108 | ``` 109 | Parameter Value 110 | 0 Residual_variance 0.994231 111 | 1 Heritability 0.005736 112 | 2 Proportion_causal 0.015887 113 | 3 sigma_beta 0.000021 114 | ``` 115 | 116 | Note that here, the SNP heritability only considers the contribution of variants on 117 | chromosome 22. -------------------------------------------------------------------------------- /viprs/utils/OptimizeResult.py: -------------------------------------------------------------------------------- 1 | 2 | class IterationConditionCounter(object): 3 | """ 4 | A class to keep track of the number of (consecutive) iterations that a condition has been met. 5 | 6 | :ivar _counter: The number of consecutive iterations that the condition has been met. 7 | :ivar _nit: The current iteration number. 8 | """ 9 | 10 | def __init__(self): 11 | """ 12 | Initialize the counter. 13 | """ 14 | self._counter = 0 15 | self._nit = 0 16 | 17 | @property 18 | def counter(self): 19 | """ 20 | :return: The number of consecutive iterations that the condition has been met. 21 | """ 22 | return self._counter 23 | 24 | def update(self, condition, iteration): 25 | """ 26 | Update the counter based on the condition. 
27 |         :param condition: The condition to check
28 |         :param iteration: The current iteration
29 |         """
30 |         if condition and (iteration == self._nit + 1):
31 |             self._counter += 1
32 |         else:
33 |             self._counter = 0
34 | 
35 |         self._nit = iteration
36 | 
37 | 
38 | class OptimizeResult(object):
39 |     """
40 |     A class to store the results/progress of an optimization algorithm.
41 |     Similar to the `OptimizeResult` class from `scipy.optimize`,
42 |     but with a few additional fields and parameters.
43 | 
44 |     :ivar message: A message about the optimization result
45 |     :ivar stop_iteration: A flag to indicate whether the optimization algorithm has stopped iterating
46 |     :ivar success: A flag to indicate whether the optimization algorithm has succeeded
47 |     :ivar fun: The current objective function value
48 |     :ivar nit: The current number of iterations
49 |     :ivar error_on_termination: A flag to indicate whether the optimization algorithm stopped due to an error.
50 |     """
51 | 
52 |     def __init__(self):
53 | 
54 |         self.message = None
55 |         self.stop_iteration = None
56 |         self.success = None
57 |         self.fun = None
58 |         self.nit = 0
59 |         self.error_on_termination = False
60 | 
61 |         self._last_drop_iter = None
62 |         self._oscillation_counter = 0
63 | 
64 |     @property
65 |     def iterations(self):
66 |         """
67 |         :return: The current number of iterations.
68 |         """
69 |         return self.nit
70 | 
71 |     @property
72 |     def objective(self):
73 |         """
74 |         :return: The current value for the objective function.
75 |         """
76 |         return self.fun
77 | 
78 |     @property
79 |     def converged(self):
80 |         """
81 |         :return: The flag indicating whether the optimization algorithm has converged.
82 |         """
83 |         return self.success
84 | 
85 |     @property
86 |     def valid_optim_result(self):
87 |         """
88 |         :return: Boolean flag indicating whether the optimization result is valid in
89 |         the sense that it either successfully converged OR it stopped iterating without
90 |         an error (due to e.g. reaching the maximum number of iterations).
91 |         """
92 |         return self.success or (self.stop_iteration and not self.error_on_termination)
93 | 
94 |     @property
95 |     def oscillation_counter(self):
96 |         """
97 |         :return: The number of oscillations in the objective function value.
98 |         """
99 |         return self._oscillation_counter
100 | 
101 |     def reset(self):
102 |         """
103 |         Reset the stored values to their initial state.
104 |         """
105 | 
106 |         self.message = None
107 |         self.stop_iteration = False
108 |         self.success = False
109 |         self.fun = None
110 |         self.nit = 0
111 |         self.error_on_termination = False
112 |         self._last_drop_iter = None
113 |         self._oscillation_counter = 0
114 | 
115 |     def _reset_oscillation_counter(self):
116 |         """
117 |         Reset the oscillation counter.
118 |         """
119 |         self._oscillation_counter = 0
120 | 
121 |     def update(self, fun, stop_iteration=False, success=False, message=None, increment=True):
122 |         """
123 |         Update the stored values with new values.
124 |         :param fun: The new objective function value
125 |         :param stop_iteration: A flag to indicate whether the optimization algorithm has stopped iterating
126 |         :param success: A flag to indicate whether the optimization algorithm has succeeded
127 |         :param message: A detailed message about the optimization result.
128 |         :param increment: A flag to indicate whether to increment the number of iterations.
129 | """ 130 | 131 | # If there's a drop in the objective, start tracking potential oscillations: 132 | if self.fun is not None and fun < self.fun: 133 | if self._last_drop_iter is not None and self.nit - self._last_drop_iter == 1: 134 | self._oscillation_counter += 1 135 | 136 | self._last_drop_iter = self.nit + 1 137 | elif self._last_drop_iter is not None and self.nit > self._last_drop_iter: 138 | # If there's no drop and the last drop is more than 2 iteration ago, 139 | # then reset the oscillation counter 140 | self._reset_oscillation_counter() 141 | 142 | self.fun = fun 143 | self.stop_iteration = stop_iteration 144 | self.success = success 145 | self.message = message 146 | 147 | self.nit += int(increment) 148 | 149 | if stop_iteration and not success and "Maximum iterations" not in message: 150 | self.error_on_termination = True 151 | 152 | def __str__(self): 153 | return str(self.__dict__) 154 | -------------------------------------------------------------------------------- /viprs/utils/math_utils.pyx: -------------------------------------------------------------------------------- 1 | # cython: linetrace=False 2 | # cython: profile=False 3 | # cython: binding=False 4 | # cython: boundscheck=False 5 | # cython: wraparound=False 6 | # cython: initializedcheck=False 7 | # cython: nonecheck=False 8 | # cython: language_level=3 9 | # cython: infer_types=True 10 | 11 | cimport cython 12 | import numpy as np 13 | from cython cimport floating 14 | from libc.math cimport exp, log 15 | from scipy.linalg.cython_blas cimport saxpy, daxpy, sdot, ddot 16 | 17 | 18 | # ------------------------------------------------------------ 19 | # BLAS implementations of some linear algebra operations 20 | 21 | @cython.boundscheck(False) 22 | @cython.wraparound(False) 23 | @cython.nonecheck(False) 24 | @cython.exceptval(check=False) 25 | cdef void scipy_blas_axpy(floating[::1] v1, floating[::1] v2, floating alpha) noexcept nogil: 26 | """v1 := v1 + alpha * v2""" 27 | cdef: 28 | int inc = 1, n=v1.shape[0] 29 | 30 | if floating is float: 31 | saxpy(&n, &alpha, &v2[0], &inc, &v1[0], &inc) 32 | else: 33 | daxpy(&n, &alpha, &v2[0], &inc, &v1[0], &inc) 34 | 35 | @cython.boundscheck(False) 36 | @cython.wraparound(False) 37 | @cython.nonecheck(False) 38 | @cython.exceptval(check=False) 39 | cdef floating scipy_blas_dot(floating[::1] v1, floating[::1] v2) noexcept nogil: 40 | """v1 . v2""" 41 | cdef: 42 | int inc = 1, n=v1.shape[0] 43 | 44 | if floating is float: 45 | return sdot(&n, &v1[0], &inc, &v2[0], &inc) 46 | else: 47 | return ddot(&n, &v1[0], &inc, &v2[0], &inc) 48 | 49 | # ------------------------------------------------------------ 50 | 51 | @cython.boundscheck(False) 52 | @cython.wraparound(False) 53 | @cython.nonecheck(False) 54 | @cython.cdivision(True) 55 | @cython.exceptval(check=False) 56 | cdef floating[::1] softmax(floating[::1] x) noexcept nogil: 57 | """ 58 | A numerically stable implementation of softmax 59 | """ 60 | 61 | cdef unsigned int i, end = x.shape[0] 62 | cdef floating s = 0., max_x = c_max(x) 63 | 64 | with nogil: 65 | for i in range(end): 66 | x[i] = exp(x[i] - max_x) 67 | s += x[i] 68 | 69 | for i in range(end): 70 | x[i] /= s 71 | 72 | return x 73 | 74 | @cython.boundscheck(False) 75 | @cython.wraparound(False) 76 | @cython.nonecheck(False) 77 | @cython.cdivision(True) 78 | @cython.exceptval(check=False) 79 | cdef floating sigmoid(floating x) noexcept nogil: 80 | """ 81 | A numerically stable version of the Sigmoid function. 
82 | """ 83 | if x < 0: 84 | exp_x = exp(x) 85 | return exp_x / (1. + exp_x) 86 | else: 87 | return 1. / (1. + exp(-x)) 88 | 89 | 90 | @cython.boundscheck(False) 91 | @cython.wraparound(False) 92 | @cython.nonecheck(False) 93 | @cython.cdivision(True) 94 | @cython.exceptval(check=False) 95 | cdef floating logit(floating x) noexcept nogil: 96 | """ 97 | The logit function (inverse of the sigmoid function) 98 | """ 99 | return log(x / (1. - x)) 100 | 101 | 102 | @cython.boundscheck(False) 103 | @cython.wraparound(False) 104 | @cython.nonecheck(False) 105 | @cython.exceptval(check=False) 106 | cdef floating dot(floating[::1] v1, floating[::1] v2) noexcept nogil: 107 | """ 108 | Dot product between vectors of the same shape 109 | """ 110 | 111 | cdef unsigned int i, end = v1.shape[0] 112 | cdef floating s = 0. 113 | 114 | with nogil: 115 | for i in range(end): 116 | s += v1[i]*v2[i] 117 | 118 | return s 119 | 120 | @cython.boundscheck(False) 121 | @cython.wraparound(False) 122 | @cython.nonecheck(False) 123 | @cython.exceptval(check=False) 124 | cdef floating vec_sum(floating[::1] v1) noexcept nogil: 125 | """ 126 | Vector summation 127 | """ 128 | 129 | cdef unsigned int i, end = v1.shape[0] 130 | cdef floating s = 0. 131 | 132 | with nogil: 133 | for i in range(end): 134 | s += v1[i] 135 | 136 | return s 137 | 138 | @cython.boundscheck(False) 139 | @cython.wraparound(False) 140 | @cython.nonecheck(False) 141 | @cython.exceptval(check=False) 142 | cdef void axpy(floating[::1] v1, floating[::1] v2, floating s) noexcept nogil: 143 | """ 144 | Elementwise addition and multiplication 145 | """ 146 | 147 | cdef unsigned int i, end = v1.shape[0] 148 | 149 | with nogil: 150 | for i in range(end): 151 | v1[i] = v1[i] + v2[i] * s 152 | 153 | @cython.boundscheck(False) 154 | @cython.wraparound(False) 155 | @cython.nonecheck(False) 156 | @cython.cdivision(True) 157 | @cython.exceptval(check=False) 158 | cdef floating[::1] clip_list(floating[::1] a, floating min_value, floating max_value) noexcept nogil: 159 | """ 160 | Iterate over a list and clip every element to be between `min_value` and `max_value` 161 | :param a: A list of floating point numbers 162 | :param min_value: Minimum values 163 | :param max_value: Maximum value 164 | """ 165 | 166 | cdef unsigned int i, end = a.shape[0] 167 | 168 | with nogil: 169 | for i in range(end): 170 | a[i] = clip(a[i], min_value, max_value) 171 | 172 | return a 173 | 174 | @cython.boundscheck(False) 175 | @cython.wraparound(False) 176 | @cython.nonecheck(False) 177 | @cython.cdivision(True) 178 | @cython.exceptval(check=False) 179 | cdef floating c_max(floating[::1] x) noexcept nogil: 180 | """ 181 | Obtain the maximum value in a vector `x` 182 | """ 183 | cdef unsigned int i, end = x.shape[0] 184 | cdef floating current_max = 0. 185 | 186 | with nogil: 187 | for i in range(end): 188 | if i == 0 or current_max < x[i]: 189 | current_max = x[i] 190 | 191 | return current_max 192 | 193 | @cython.boundscheck(False) 194 | @cython.wraparound(False) 195 | @cython.nonecheck(False) 196 | @cython.exceptval(check=False) 197 | cdef floating clip(floating a, floating min_value, floating max_value) noexcept nogil: 198 | """ 199 | Clip a scalar value `a` to be between `min_value` and `max_value` 200 | """ 201 | 202 | if a < min_value: 203 | a = min_value 204 | if a > max_value: 205 | a = max_value 206 | 207 | return a 208 | 209 | def bernoulli_entropy(p): 210 | """ 211 | Compute the entropy of a Bernoulli variable given a vector of probabilities. 
212 | :param p: A vector (or scalar) of probabilities between zero and one, 0. < p < 1. 213 | """ 214 | return -(p*np.log(p) + (1. - p)*np.log(1. - p)) 215 | -------------------------------------------------------------------------------- /viprs/utils/compute_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import psutil 4 | 5 | 6 | def fits_in_memory(alloc_size, max_prop=.9): 7 | """ 8 | Check whether there's enough memory resources to load an object 9 | with the given allocation size (in MB). 10 | :param alloc_size: The allocation size 11 | :param max_prop: The maximum proportion of available memory allowed for the object 12 | """ 13 | 14 | avail_mem = psutil.virtual_memory().available / (1024.0 ** 2) 15 | 16 | if alloc_size / avail_mem > max_prop: 17 | return False 18 | else: 19 | return True 20 | 21 | 22 | def dict_concat(d, axis=0): 23 | """ 24 | Concatenate the values of a dictionary into a single vector 25 | :param d: A dictionary where values are numeric scalars or vectors 26 | :param axis: Concatenate along given axis. 27 | """ 28 | if len(d) == 1: 29 | return d[next(iter(d))] 30 | else: 31 | return np.concatenate([d[c] for c in sorted(d.keys())], axis=axis) 32 | 33 | 34 | def dict_max(d, axis=None): 35 | """ 36 | Estimate the maximum of the values of a dictionary 37 | :param d: A dictionary where values are numeric scalars or vectors 38 | :param axis: Perform aggregation along given axis. 39 | """ 40 | return np.max(np.array([np.max(v, axis=axis) for v in d.values()]), axis=axis) 41 | 42 | 43 | def dict_mean(d, axis=None): 44 | """ 45 | Estimate the mean of the values of a dictionary 46 | :param d: A dictionary where values are numeric scalars or vectors 47 | :param axis: Perform aggregation along given axis. 48 | """ 49 | return np.mean(np.array([np.mean(v, axis=axis) for v in d.values()]), axis=axis) 50 | 51 | 52 | def dict_sum(d, axis=None, transform=None): 53 | """ 54 | Estimate the sum of the values of a dictionary 55 | :param d: A dictionary where values are numeric scalars or vectors 56 | :param axis: Perform aggregation along given axis. 57 | :param transform: Transformation to apply before summing. 
58 |     """
59 |     if transform is None:
60 |         return np.sum(np.array([np.sum(v, axis=axis) for v in d.values()]), axis=axis)
61 |     else:
62 |         return np.sum(np.array([np.sum(transform(v), axis=axis) for v in d.values()]), axis=axis)
63 | 
64 | 
65 | def dict_elementwise_transform(d, transform):
66 |     """
67 |     Apply a transformation to values of a dictionary
68 |     :param d: A dictionary where values are numeric scalars or vectors
69 |     :param transform: A function to apply element-wise to the values of the dictionary
70 |     """
71 |     return {c: np.vectorize(transform)(v) for c, v in d.items()}
72 | 
73 | 
74 | def dict_elementwise_dot(d1, d2):
75 |     """
76 |     Apply element-wise product between the values of two dictionaries
77 | 
78 |     :param d1: A dictionary where values are numeric scalars or vectors
79 |     :param d2: A dictionary where values are numeric scalars or vectors
80 |     """
81 |     return {c: d1[c]*d2[c] for c, v in d1.items()}
82 | 
83 | 
84 | def dict_dot(d1, d2):
85 |     """
86 |     Perform dot product on the elements of d1 and d2
87 |     :param d1: A dictionary where values are numeric scalars or vectors
88 |     :param d2: A dictionary where values are numeric scalars or vectors
89 |     """
90 |     return np.sum([np.dot(d1[c], d2[c]) for c in d1.keys()])
91 | 
92 | 
93 | def dict_set(d, value):
94 |     """
95 |     :param d: A dictionary where values are numeric vectors
96 |     :param value: A value to set for all vectors
97 |     """
98 |     for c in d:
99 |         d[c][:] = value
100 | 
101 |     return d
102 | 
103 | 
104 | def dict_repeat(value, shapes):
105 |     """
106 |     Given a value, create a dictionary where the value is repeated
107 |     according to the shapes parameter
108 |     :param shapes: A dictionary of shapes. Key is arbitrary, value is the shape passed to `np.ones`
109 |     :param value: The value to repeat
110 |     """
111 |     return {c: value*np.ones(shp) for c, shp in shapes.items()}
112 | 
113 | 
114 | def expand_column_names(c_name, shape, sep='_'):
115 |     """
116 |     Given a desired column name `c_name` and a matrix `shape`
117 |     that we'd like to apply the column name to, return a list of
118 |     column names for every column in the matrix. The column names will be
119 |     in the form of `c_name` followed by an index, separated by `sep`.
120 | 
121 |     For example, if the column name is `BETA`, the
122 |     shape is (100, 3) and the separator is `_`, we return a list with:
123 |     [`BETA_0`, `BETA_1`, `BETA_2`]
124 | 
125 |     If the matrix in question is a vector, we just return the column name
126 |     without any indices appended to it.
127 | 
128 |     :param c_name: A string object
129 |     :param shape: The shape of a numpy matrix or vector
130 |     :param sep: The separator
131 | 
132 |     :return: A list of column names
133 |     """
134 | 
135 |     if len(shape) < 2:
136 |         return [c_name]
137 |     elif shape[1] == 1:
138 |         return [c_name]
139 |     else:
140 |         return [f'{c_name}{sep}{i}' for i in range(shape[1])]
141 | 
142 | 
143 | def combine_coefficient_tables(coef_tables, coef_col='BETA'):
144 |     """
145 |     Combine a list of coefficient tables (output from a PRS model) into a single
146 |     table that can be used for downstream tasks, such as scoring and evaluation. Note that
147 |     this implementation assumes that the coefficient tables were generated for the same
148 |     set of variants, from a grid-search or similar procedure.
149 | 
150 |     :param coef_tables: A list of pandas dataframes containing variant information as well as
151 |     inferred coefficients.
152 |     :param coef_col: The name of the column containing the coefficients.
153 |     :return: A single pandas dataframe with the combined coefficients.
The new coefficient columns will be
154 |     labelled as BETA_0, BETA_1, etc.
155 |     """
156 | 
157 |     # Sanity checks:
158 |     assert all([coef_col in t.columns for t in coef_tables]), "All tables must contain the coefficient column."
159 |     assert all([len(t) == len(coef_tables[0]) for t in coef_tables]), "All tables must have the same number of rows."
160 | 
161 |     if len(coef_tables) == 1:
162 |         return coef_tables[0]
163 | 
164 |     ref_table = coef_tables[0].copy()
165 |     ref_table.rename(columns={coef_col: f'{coef_col}_0'}, inplace=True)
166 | 
167 |     # Extract the coefficients from the other tables:
168 |     return pd.concat([ref_table, *[t[[coef_col]].rename(columns={coef_col: f'{coef_col}_{i}'})
169 |                                    for i, t in enumerate(coef_tables[1:], 1)]], axis=1)
170 | 
-------------------------------------------------------------------------------- /bin/viprs_score: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | """
4 | Compute Polygenic Scores for Test Samples
5 | ----------------------------
6 | 
7 | This is a commandline script that computes polygenic scores for test samples
8 | given effect size estimates from VIPRS. The script can work with effect sizes from
9 | other software, as long as they're formatted in the same way as VIPRS `.fit` files.
10 | 
11 | Usage:
12 | 
13 |     python -m viprs_score -f <fit_files> --bfile <bed_files> --output-file <output_file>
14 | 
15 | - `fit_files` is the path to the file(s) with the output parameter estimates from VIPRS.
16 | - `bed_files` is the path to the BED file(s) containing the genotype data.
17 | - `output_file` is the output file where to store the polygenic scores (with no extension).
18 | 
19 | """
20 | 
21 | # Setup the logger:
22 | import logging
23 | logger = logging.getLogger(__name__)
24 | 
25 | 
26 | def main():
27 | 
28 |     import argparse
29 |     import viprs as vp
30 | 
31 |     print("\n" + vp.make_ascii_logo(
32 |         desc='< Compute Polygenic Scores for Test Samples >',
33 |         left_padding=10
34 |     ) + "\n", flush=True)
35 | 
36 |     parser = argparse.ArgumentParser(description="""
37 |         Commandline arguments for computing polygenic scores
38 |     """)
39 | 
40 |     parser.add_argument('-f', '--fit-files', dest='fit_files', type=str, required=True,
41 |                         help='The path to the file(s) with the output parameter estimates from VIPRS. '
42 |                              'You may use a wildcard here if fit files are stored '
43 |                              'per-chromosome (e.g. "prs/chr_*.fit")')
44 |     parser.add_argument('--bfile', dest='bed_files', type=str, required=True,
45 |                         help='The BED files containing the genotype data. '
46 |                              'You may use a wildcard here (e.g.
"data/chr_*.bed")') 47 | parser.add_argument('--output-file', dest='output_file', type=str, required=True, 48 | help='The output file where to store the polygenic scores (with no extension).') 49 | 50 | parser.add_argument('--temp-dir', dest='temp_dir', type=str, default='temp', 51 | help='The temporary directory where to store intermediate files.') 52 | 53 | parser.add_argument('--keep', dest='keep', type=str, 54 | help='A plink-style keep file to select a subset of individuals for the test set.') 55 | parser.add_argument('--extract', dest='extract', type=str, 56 | help='A plink-style extract file to select a subset of SNPs for scoring.') 57 | parser.add_argument('--backend', dest='backend', type=str, default='xarray', 58 | choices={'xarray', 'plink'}, 59 | help='The backend software used for computations with the genotype matrix.') 60 | parser.add_argument('--threads', dest='threads', type=int, default=1, 61 | help='The number of threads to use for computations.') 62 | parser.add_argument('--compress', dest='compress', action='store_true', default=False, 63 | help='Compress the output file') 64 | 65 | parser.add_argument('--log-level', dest='log_level', type=str, default='WARNING', 66 | choices={'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'}, 67 | help='The logging level for the console output.') 68 | 69 | args = parser.parse_args() 70 | 71 | # ---------------------------------------------------------- 72 | import os.path as osp 73 | from magenpy.utils.system_utils import makedir, get_filenames, setup_logger 74 | from magenpy.GWADataLoader import GWADataLoader 75 | from viprs.model.BayesPRSModel import BayesPRSModel 76 | 77 | # ---------------------------------------------------------- 78 | # Setup the logger: 79 | 80 | # Create the output directory: 81 | makedir(osp.dirname(args.output_file)) 82 | 83 | # Clear the log file: 84 | log_file = f"{args.output_file}.log" 85 | open(log_file, 'w').close() 86 | 87 | # Set up the module loggers: 88 | setup_logger(modules=['viprs', 'magenpy'], 89 | log_file=log_file, 90 | log_level=args.log_level) 91 | 92 | # Set up the logger for the main module: 93 | setup_logger(loggers=[logger], 94 | log_file=log_file, 95 | log_format='%(message)s', 96 | log_level=['INFO', args.log_level][logging.getLevelName(args.log_level) < logging.INFO]) 97 | 98 | # ---------------------------------------------------------- 99 | 100 | logger.info('{:-^100}\n'.format(' Parsed arguments ')) 101 | 102 | for key, val in vars(args).items(): 103 | if val is not None and val != parser.get_default(key): 104 | logger.info(f"-- {key}: {val}") 105 | 106 | # ---------------------------------------------------------- 107 | logger.info('\n{:-^100}\n'.format(' Reading input data ')) 108 | 109 | test_data = GWADataLoader(args.bed_files, 110 | keep_file=args.keep, 111 | extract_file=args.extract, 112 | min_mac=None, 113 | min_maf=None, 114 | backend=args.backend, 115 | temp_dir=args.temp_dir, 116 | threads=args.threads) 117 | prs_m = BayesPRSModel(test_data) 118 | 119 | fit_files = get_filenames(args.fit_files, extension='.fit') 120 | 121 | if len(fit_files) < 1: 122 | err_msg = "Did not find PRS coefficient files at:\n" + args.fit_files 123 | logger.error(err_msg) 124 | raise FileNotFoundError(err_msg) 125 | 126 | prs_m.read_inferred_parameters(fit_files) 127 | 128 | # ---------------------------------------------------------- 129 | logger.info('\n{:-^100}\n'.format(' Scoring ')) 130 | 131 | # Predict on the test set: 132 | prs = test_data.score(prs_m.get_posterior_mean_beta()) 133 | 
134 | # Save the PRS as a table: 135 | 136 | ind_table = test_data.to_individual_table().copy() 137 | ind_table['PRS'] = prs 138 | 139 | # Clean up all the intermediate files/directories 140 | test_data.cleanup() 141 | 142 | logger.info(f"\n>>> Writing the polygenic scores to:\n {osp.dirname(args.output_file)}") 143 | 144 | # If the user wants the files to be compressed, append `.gz` to the name: 145 | c_ext = ['', '.gz'][args.compress] 146 | 147 | # Output the scores: 148 | makedir(osp.dirname(args.output_file)) 149 | ind_table.to_csv(args.output_file + '.prs' + c_ext, index=False, sep="\t") 150 | 151 | 152 | if __name__ == '__main__': 153 | main() 154 | -------------------------------------------------------------------------------- /viprs/eval/pseudo_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def _match_variant_stats(test_gdl, prs_beta_table): 5 | """ 6 | Match the standardized marginal betas from the validation set to the inferred PRS effect sizes. 7 | This function takes a `GWADataLoader` object from the validation set (with matched LD matrices and 8 | GWAS summary statistics) and a PRS table object and returns a tuple of three arrays: 9 | 10 | #. The standardized marginal betas from the validation set 11 | #. The inferred PRS effect sizes 12 | #. The LD-weighted PRS effect sizes (q) 13 | 14 | :param test_gdl: A `GWADataLoader` object from the validation or test set. 15 | :param prs_beta_table: A pandas DataFrame with the PRS effect sizes. Must contain 16 | the columns: CHR, SNP, A1, A2, BETA. 17 | 18 | :return: A tuple of three arrays: (1) The standardized marginal betas from the validation set, 19 | (2) The inferred PRS effect sizes, (3) The LD-weighted PRS effect sizes (q). 20 | """ 21 | 22 | from magenpy import GWADataLoader 23 | 24 | # Sanity checks: 25 | assert isinstance(test_gdl, GWADataLoader), "The test/validation set must be an instance of GWADataLoader." 26 | assert test_gdl.ld is not None, "The test/validation set must have LD matrices initialized." 27 | assert test_gdl.sumstats_table is not None, "The test/validation set must have summary statistics initialized." 28 | 29 | from magenpy.utils.model_utils import merge_snp_tables 30 | 31 | validation_tab = test_gdl.to_summary_statistics_table(col_subset=['CHR', 'SNP', 'A1', 'A2', 'STD_BETA'], 32 | per_chromosome=True) 33 | 34 | required_cols = ['CHR', 'SNP', 'A1', 'A2'] 35 | for col in required_cols: 36 | assert col in prs_beta_table.columns, f"The PRS effect sizes table must contain a column named {col}." 
37 | 38 | validation_beta = [] 39 | prs_beta = [] 40 | ld_weighted_beta = [] 41 | 42 | if 'BETA' in prs_beta_table.columns: 43 | beta_cols = ['BETA'] 44 | else: 45 | beta_cols = [col for col in prs_beta_table.columns 46 | if 'BETA' in col and 'VAR' not in col] 47 | assert len(beta_cols) > 0, ("The PRS effect sizes table must contain " 48 | "a column named BETA or BETA_0, BETA_1, etc.") 49 | 50 | per_chrom_prs_tables = dict(tuple(prs_beta_table.groupby('CHR'))) 51 | 52 | for chrom, tab in validation_tab.items(): 53 | 54 | if chrom not in per_chrom_prs_tables: 55 | continue 56 | 57 | c_df = merge_snp_tables(tab, 58 | per_chrom_prs_tables[chrom], 59 | how='left', 60 | signed_statistics=beta_cols) 61 | 62 | validation_beta.append(tab['STD_BETA'].values) 63 | prs_beta.append(c_df[beta_cols].fillna(0.).values) 64 | ld_weighted_beta.append(test_gdl.ld[chrom].dot(prs_beta[-1])) 65 | 66 | test_gdl.ld[chrom].release() 67 | 68 | return (np.concatenate(validation_beta), 69 | np.concatenate(prs_beta, axis=0), 70 | np.concatenate(ld_weighted_beta)) 71 | 72 | 73 | def pseudo_r2(test_gdl, prs_beta_table): 74 | """ 75 | Compute the R-Squared metric (proportion of variance explained) for a given 76 | PRS using standardized marginal betas from an independent test set. 77 | Here, we follow the pseudo-validation procedures outlined in Mak et al. (2017) and 78 | Yang and Zhou (2020), where the proportion of phenotypic variance explained by the PRS 79 | in an independent validation cohort can be approximated with: 80 | 81 | R2(PRS, y) ~= 2*r'b - b'Sb 82 | 83 | Where `r` is the standardized marginal beta from a validation/test set, 84 | `b` is the posterior mean for the effect size of each variant and `S` is the LD matrix. 85 | 86 | :param test_gdl: An instance of `GWADataLoader` with the summary statistics table initialized. 87 | :param prs_beta_table: A pandas DataFrame with the PRS effect sizes. Must contain 88 | the columns: CHR, SNP, A1, A2, BETA. 89 | """ 90 | 91 | # std_beta, prs_beta, q = _match_variant_stats(test_gdl, prs_beta_table) 92 | 93 | # rb = np.sum((prs_beta.T * std_beta).T, axis=0) 94 | # bsb = np.sum(prs_beta*q, axis=0) 95 | 96 | # return 2*rb - bsb 97 | 98 | # NOTE: The above procedure can be biased/problematic when the LD matrix is highly 99 | # sparsified. For now, we will use the squared Pearson correlation as a proxy for the R^2 metric: 100 | 101 | return pseudo_pearson_r(test_gdl, prs_beta_table)**2 102 | 103 | 104 | def pseudo_pearson_r(test_gdl, prs_beta_table): 105 | """ 106 | Perform pseudo-validation of the inferred effect sizes by comparing them to 107 | standardized marginal betas from an independent validation set. Here, we follow the pseudo-validation 108 | procedures outlined in Mak et al. (2017) and Yang and Zhou (2020), where 109 | the correlation between the PRS and the phenotype in an independent validation 110 | cohort can be approximated with: 111 | 112 | Corr(PRS, y) ~= r'b / sqrt(b'Sb) 113 | 114 | Where `r` is the standardized marginal beta from a validation set, 115 | `b` is the posterior mean for the effect size of each variant and `S` is the LD matrix. 116 | 117 | :param test_gdl: An instance of `GWADataLoader` with the summary statistics table initialized. 118 | :param prs_beta_table: A pandas DataFrame with the PRS effect sizes. Must contain 119 | the columns: CHR, SNP, A1, A2, BETA. 
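    :return: The approximate correlation between the PRS and the phenotype (one estimate per BETA column
    when multiple sets of effect sizes are provided).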
120 | """ 121 | 122 | std_beta, prs_beta, q = _match_variant_stats(test_gdl, prs_beta_table) 123 | 124 | rb = np.sum((prs_beta.T * std_beta).T, axis=0) 125 | bsb = np.sum(prs_beta * q, axis=0) 126 | 127 | return rb / np.sqrt(bsb) 128 | 129 | 130 | def _streamlined_pseudo_r2(validation_beta, prs_beta, ldw_prs_beta): 131 | """ 132 | This function implements a streamlined version of the pseudo-R^2 computation 133 | where we assume that the LD matrix is shared between the training and validation set, 134 | and thus we don't need to recompute the LD-weighted PRS effect sizes. 135 | 136 | This function also assumes that the validation_beta and prs_beta arrays are already 137 | standardized and matched to each other. 138 | 139 | This is useful for cross-validation purposes in the context of PRS analysis. 140 | 141 | :param validation_beta: A numpy array of standardized marginal betas from the validation set. 142 | :param prs_beta: A numpy array/matrix of inferred PRS effect sizes. 143 | :param ldw_prs_beta: The LD-weighted PRS effect sizes from a fitted PRS model (i.e. the result of 144 | multiplying the PRS effect sizes by the LD matrix). 145 | 146 | :return: The pseudo-R^2 metric(s). 147 | """ 148 | 149 | rb = np.sum((prs_beta.T * validation_beta).T, axis=0) 150 | bsb = np.sum(prs_beta * ldw_prs_beta, axis=0) 151 | 152 | return rb**2 / bsb 153 | -------------------------------------------------------------------------------- /viprs/model/vi/e_step_cpp.pyx: -------------------------------------------------------------------------------- 1 | # distutils: language = c++ 2 | # sources: model/vi/e_step.hpp 3 | 4 | from cython cimport floating 5 | 6 | 7 | cdef extern from "e_step.hpp" nogil: 8 | 9 | bint blas_supported() noexcept nogil 10 | bint omp_supported() noexcept nogil 11 | 12 | void blas_axpy[T](T* y, T* x, T alpha, int size) noexcept nogil 13 | T blas_dot[T](T* x, T* y, int size) noexcept nogil 14 | 15 | void e_step[T, U, I](int c_size, 16 | int* ld_left_bound, 17 | I* ld_indptr, 18 | U* ld_data, 19 | T* std_beta, 20 | T* var_gamma, 21 | T* var_mu, 22 | T* eta, 23 | T* q, 24 | T* eta_diff, 25 | T* u_logs, 26 | T* sqrt_half_var_tau, 27 | T* mu_mult, 28 | T dq_scale, 29 | int threads, 30 | bint low_memory) noexcept nogil 31 | 32 | void e_step_mixture[T, U, I](int c_size, 33 | int K, 34 | int* ld_left_bound, 35 | I* ld_indptr, 36 | U* ld_data, 37 | T* std_beta, 38 | T* var_gamma, 39 | T* var_mu, 40 | T* eta, 41 | T* q, 42 | T* eta_diff, 43 | T* log_null_pi, 44 | T* u_logs, 45 | T* sqrt_half_var_tau, 46 | T* mu_mult, 47 | T dq_scale, 48 | int threads, 49 | bint low_memory) noexcept nogil 50 | 51 | void e_step_grid[T, U, I](int c_size, 52 | int n_active_models, 53 | int* active_model_idx, 54 | int* ld_left_bound, 55 | I* ld_indptr, 56 | U* ld_data, 57 | T* std_beta, 58 | T* var_gamma, 59 | T* var_mu, 60 | T* eta, 61 | T* q, 62 | T* eta_diff, 63 | T* u_logs, 64 | T* half_var_tau, 65 | T* mu_mult, 66 | T dq_scale, 67 | int threads, 68 | bint low_memory) noexcept nogil 69 | 70 | 71 | cpdef check_blas_support(): 72 | return blas_supported() 73 | 74 | 75 | cpdef check_omp_support(): 76 | return omp_supported() 77 | 78 | 79 | cdef void cpp_blas_axpy(floating[::1] v1, floating[::1] v2, floating alpha) noexcept nogil: 80 | """v1 := v1 + alpha * v2""" 81 | cdef int size = v1.shape[0] 82 | blas_axpy(&v1[0], &v2[0], alpha, size) 83 | 84 | 85 | cdef floating cpp_blas_dot(floating[::1] v1, floating[::1] v2) noexcept nogil: 86 | """v1 := v1.Tv2""" 87 | cdef int size = v1.shape[0] 88 | return blas_dot(&v1[0], 
&v2[0], size) 89 | 90 | 91 | cpdef void cpp_e_step(int[::1] ld_left_bound, 92 | indptr_type[::1] ld_indptr, 93 | noncomplex_numeric[::1] ld_data, 94 | floating[::1] std_beta, 95 | floating[::1] var_gamma, 96 | floating[::1] var_mu, 97 | floating[::1] eta, 98 | floating[::1] q, 99 | floating[::1] eta_diff, 100 | floating[::1] u_logs, 101 | floating[::1] sqrt_half_var_tau, 102 | floating[::1] mu_mult, 103 | floating dq_scale, 104 | int threads, 105 | bint low_memory) noexcept nogil: 106 | 107 | e_step(var_mu.shape[0], 108 | &ld_left_bound[0], 109 | &ld_indptr[0], 110 | &ld_data[0], 111 | &std_beta[0], 112 | &var_gamma[0], 113 | &var_mu[0], 114 | &eta[0], 115 | &q[0], 116 | &eta_diff[0], 117 | &u_logs[0], 118 | &sqrt_half_var_tau[0], 119 | &mu_mult[0], 120 | dq_scale, 121 | threads, 122 | low_memory) 123 | 124 | 125 | cpdef void cpp_e_step_mixture(int[::1] ld_left_bound, 126 | indptr_type[::1] ld_indptr, 127 | noncomplex_numeric[::1] ld_data, 128 | floating[::1] std_beta, 129 | floating[:, ::1] var_gamma, 130 | floating[:, ::1] var_mu, 131 | floating[::1] eta, 132 | floating[::1] q, 133 | floating[::1] eta_diff, 134 | floating[::1] log_null_pi, 135 | floating[:, ::1] u_logs, 136 | floating[:, ::1] sqrt_half_var_tau, 137 | floating[:, ::1] mu_mult, 138 | floating dq_scale, 139 | int threads, 140 | bint low_memory) noexcept nogil: 141 | 142 | e_step_mixture(var_mu.shape[0], 143 | var_mu.shape[1], 144 | &ld_left_bound[0], 145 | &ld_indptr[0], 146 | &ld_data[0], 147 | &std_beta[0], 148 | &var_gamma[0, 0], 149 | &var_mu[0, 0], 150 | &eta[0], 151 | &q[0], 152 | &eta_diff[0], 153 | &log_null_pi[0], 154 | &u_logs[0, 0], 155 | &sqrt_half_var_tau[0, 0], 156 | &mu_mult[0, 0], 157 | dq_scale, 158 | threads, 159 | low_memory) 160 | 161 | cpdef void cpp_e_step_grid(int[::1] ld_left_bound, 162 | indptr_type[::1] ld_indptr, 163 | noncomplex_numeric[::1] ld_data, 164 | floating[::1] std_beta, 165 | floating[::1, :] var_gamma, 166 | floating[::1, :] var_mu, 167 | floating[::1, :] eta, 168 | floating[::1, :] q, 169 | floating[::1, :] eta_diff, 170 | floating[::1, :] u_logs, 171 | floating[::1, :] half_var_tau, 172 | floating[::1, :] mu_mult, 173 | floating dq_scale, 174 | int[:] active_model_idx, 175 | int threads, 176 | bint low_memory) noexcept nogil: 177 | 178 | e_step_grid(var_mu.shape[0], 179 | active_model_idx.shape[0], 180 | &active_model_idx[0], 181 | &ld_left_bound[0], 182 | &ld_indptr[0], 183 | &ld_data[0], 184 | &std_beta[0], 185 | &var_gamma[0, 0], 186 | &var_mu[0, 0], 187 | &eta[0, 0], 188 | &q[0, 0], 189 | &eta_diff[0, 0], 190 | &u_logs[0, 0], 191 | &half_var_tau[0, 0], 192 | &mu_mult[0, 0], 193 | dq_scale, 194 | threads, 195 | low_memory) 196 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 7 | 8 | ## [0.1.3] - 2025-04-22 9 | 10 | ### Changed 11 | 12 | - Fixed bugs in `VIPRSGridSearch` and `VIPRSBMA` models, specifically how they were handling `_log_var_tau`, 13 | and the hyperparameters objects after selecting best models or performing model averaging. 14 | - Fixed bug in how `viprs_fit` handles validation `gdl`s when the user passes genotype data. 
15 | - Updated interfaces in `HyperparameterSearch` script to make it more flexible and efficient. Primarily,
16 |   I added a shared memory object for the LD matrix to avoid redundant memory usage when fitting multiple
17 |   models in parallel. (** WORK IN PROGRESS **).
18 | - Updated implementation of `pseudo_r2` to use the square of the pseudo correlation coefficient instead. The previous
19 |   implementation can be problematic with highly sparsified LD matrices.
20 | - Updated implementation of `VIPRSGrid` to be better integrated with the `VIPRS` class. The new implementation
21 |   also allows for fitting the grid in a `pathwise` fashion (now default behavior), where we use
22 |   parameter estimates from previous grid points as warm-start initialization for the current grid point.
23 | - Removed `VIPRSGridSearch` and `VIPRSBMA` classes for now. These functions are implemented in `grid_utils.py` instead
24 |   and they can be applied generically to any `VIPRSGrid` model.
25 |
26 | ### Added
27 |
28 | - Added `viprs-cli-example.ipynb` notebook to demonstrate how to use the `viprs` commandline interface.
29 | - Added documentation page for Downloading LD matrices.
30 | - Added new utility function `combine_coefficient_tables` to combine the output from multiple VIPRS models.
31 | - Added more thorough tests for the various models + CLI scripts.
32 | - Added `PeakMemoryProfiler` to `viprs_fit` to more accurately track peak memory usage. Temporary solution,
33 |   this will be moved to `magenpy` later on.
34 | - Added support for splitting GWAS sumstats into training/validation sets and exposed appropriate interfaces
35 |   in the base class `BayesPRSModel`.
36 | - Added `IterationConditionCounter` class to keep track of the number of consecutive iterations
37 |   where a certain condition is met. This is used to monitor convergence of the optimization routine.
38 |
39 | ## [0.1.2] - 2024-12-25
40 |
41 | ### Changed
42 |
43 | - Fixed bug in implementation of `.fit` method of VIPRS models. Specifically,
44 |   there was an issue with the `continued=True` flag not working because the `OptimizeResult`
45 |   object wasn't refreshed.
46 | - Replaced `print` statements with `logging` where appropriate (still needs some more work).
47 | - Updated the way we measure peak memory in `viprs_fit`.
48 | - Updated `dict_concat` to just return the element if there's a single entry.
49 | - Refactored parts of `VIPRS` to cache some recurring computations.
50 | - Updated `VIPRSBMA` & `VIPRSGridSearch` to only consider models that
51 |   successfully converged.
52 | - Fixed bug in `pseudo_metrics` when extracting summary statistics data.
53 | - Streamlined evaluation code.
54 | - Refactored code to slightly reduce import/load time.
55 | - Fixed bug in `viprs_evaluate`.
56 |
57 | ### Added
58 |
59 | - Added SNP position to output table from VIPRS objects.
60 | - Added measure of time taken to prepare data in `viprs_fit`.
61 | - Added option to keep long-range LD regions in `viprs_fit`.
62 | - Added convergence check based on parameter values.
63 | - Added `min_iter` parameter to `.fit` methods to ensure CAVI is run for at least `min_iter` iterations.
64 | - Added separate method for initializing optimization-related objects.
65 | - Added regularization penalty `lambda_min`.
66 | - Added Spearman R and residualized R-Squared metrics to continuous metrics.
67 |
68 | ## [0.1.1] - 2024-04-24
69 |
70 | ### Changed
71 |
72 | - Fixed bugs in the E-Step benchmarking script.
73 | - Re-wrote the logic for finding BLAS libraries in the `setup.py` script.
:crossed_fingers:
74 | - Fixed bugs in CI / GitHub Actions scripts.
75 |
76 | ### Added
77 |
78 | - `Dockerfile`s for both `cli` and `jupyter` modes.
79 |
80 | ## [0.1.0] - 2024-04-05
81 |
82 | A large-scale restructuring of the code base to improve efficiency and usability.
83 |
84 | ### Changed
85 |
86 | - Moved plotting script to its own separate module.
87 | - Updated some method names / commandline flags to be consistent throughout.
88 | - Updated the `VIPRS` class to allow for more flexibility in the optimization process.
89 | - Removed the `VIPRSAlpha` model for now. This will be re-implemented in the future,
90 |   using better interfaces / data structures.
91 | - Moved all hyperparameter search classes/models to their own directory.
92 | - Restructured the `viprs_fit` commandline script to make the code cleaner,
93 |   do better sanity checking, and introduce process parallelism over chromosomes.
94 |
95 | ### Added
96 |
97 | - Basic integration testing with `pytest` and GitHub workflows.
98 | - Documentation for the entire package using `mkdocs`.
99 | - Integration testing / automated building with GitHub workflows.
100 | - New self-contained implementation of E-Step in `Cython` and `C++`.
101 |     - Uses `OpenMP` for parallelism across chunks of variants.
102 |     - Allows for de-quantization on the fly of the LD matrix.
103 |     - Uses BLAS linear algebra operations where possible.
104 |     - Allows model fitting with only
105 | - Benchmarking scripts (`benchmark_e_step.py`) to compare computational performance of different implementations.
106 | - Added functionality to allow the user to track time / memory utilization in `viprs_fit`.
107 | - Added `OptimizeResult` class to keep track of the info/parameters of EM optimization.
108 | - New evaluation metrics
109 |     - `pseudo_metrics` has been moved to its own module to allow for more flexibility in evaluation.
110 |     - New evaluation metrics for binary traits: `nagelkerke_r2`, `mcfadden_r2`,
111 |       `cox_snell_r2`, `liability_r2`, `liability_probit_r2`, `liability_logit_r2`.
112 |     - New function to compute standard errors / test statistics for all R-Squared metrics.
113 |
114 | ## [0.0.4] - 2022-09-07
115 |
116 | ### Changed
117 |
118 | - Removed the `--fast-math` compiler flag due to concerns about
119 |   numerical precision (e.g. [Beware of fast-math](https://simonbyrne.github.io/notes/fastmath/)).
120 |
121 | ## [0.0.3] - 2022-09-06
122 |
123 | ### Added
124 |
125 | - New implementation for the e-step in `VIPRS`, where we multiply with the rows of the
126 |   LD matrix only once.
127 | - Added support for deterministic annealing in the `VIPRS` optimization.
128 | - Added support for `pseudo_validation` as a metric for choosing models. Now, the
129 |   `VIPRS` class has a method called `pseudo_validate`.
130 | - New implementations for grid-based models: `VIPRSGrid`, `VIPRSGridSearch`, `VIPRSBMA`.
131 | - New Python implementation of the `LDPredInf` model, using the `viprs`/`magenpy`
132 |   data structures.
133 | - MIT license for the software.
134 |
135 | ### Changed
136 |
137 | - Corrected implementation of Mean Squared Error (MSE) metric.
138 | - Changed the `c_utils.pyx` script to be `math_utils.pyx`.
139 | - Updated documentation in `README` to follow latest APIs.
140 |
141 | ## [0.0.2] - 2022-06-28
142 |
143 | ### Changed
144 |
145 | - Updating the dependency structure between `viprs` and `magenpy`.
146 |
147 | ## [0.0.1] - 2022-06-28
148 |
149 | ### Added
150 |
151 | - Refactoring the code in the `viprs` repository and re-organizing it into a Python package.
152 | - Added a module to compute predictive performance metrics.
153 | - Added commandline scripts to allow users to access some of the functionalities of `viprs` without
154 |   necessarily having to write Python code.
155 | - Added the estimate of the posterior variance to the output from the module.
156 |
157 | ### Changed
158 |
159 | - Updated plotting script.
160 | - Updated implementation of `VIPRSMix`, `VIPRSAlpha`, etc. to inherit most
161 |   of their functionalities from the base `VIPRS` class.
162 | - Cleaned up implementation of hyperparameter search modules.
163 |
164 |
-------------------------------------------------------------------------------- /viprs/model/gridsearch/grid_utils.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | # Set up the logger:
4 | import logging
5 | logger = logging.getLogger(__name__)
6 |
7 |
8 | def select_best_model(viprs_grid_model, validation_gdl=None, criterion='ELBO'):
9 |     """
10 |     From the grid of models that were fit to the data, select the best
11 |     model according to the specified `criterion`. If the criterion is the ELBO,
12 |     the model with the highest ELBO will be selected. If the criterion is
13 |     validation or pseudo-validation, the model with the highest R^2 on the
14 |     held-out validation set will be selected.
15 |
16 |     :param viprs_grid_model: An instance of `VIPRSGrid` or `VIPRSGridPathwise` containing the fitted grid
17 |     of VIPRS models.
18 |     :param validation_gdl: An instance of `GWADataLoader` containing data from the validation set.
19 |     :param criterion: The criterion for selecting the best model.
20 |     Options are: (`ELBO`, `validation`, `pseudo_validation`)
21 |     """
22 |
23 |     assert criterion in ('ELBO', 'validation', 'pseudo_validation')
24 |
25 |     if criterion == 'validation':
26 |         assert validation_gdl is not None, "Validation GWADataLoader must be provided for validation criterion."
27 |     elif criterion == 'pseudo_validation' and validation_gdl is None and viprs_grid_model.validation_std_beta is None:
28 |         raise ValueError("Validation GWADataLoader or standardized betas from a validation set must be "
29 |                          "initialized for the pseudo_validation criterion.")
30 |
31 |     # Extract the models that converged successfully:
32 |     models_converged = viprs_grid_model.valid_terminated_models
33 |     best_model_idx = None
34 |
35 |     if np.sum(models_converged) < 2:
36 |         raise ValueError("Less than two models converged successfully.
Cannot perform model selection.") 37 | else: 38 | 39 | if criterion == 'ELBO': 40 | elbo = viprs_grid_model.elbo() 41 | elbo[~models_converged] = -np.inf 42 | best_model_idx = np.argmax(elbo) 43 | elif criterion == 'validation': 44 | 45 | assert validation_gdl is not None 46 | assert validation_gdl.sample_table is not None 47 | assert validation_gdl.sample_table.phenotype is not None 48 | 49 | from viprs.eval.continuous_metrics import r2 50 | 51 | prs = viprs_grid_model.predict(test_gdl=validation_gdl) 52 | prs_r2 = np.array([r2(prs[:, i], validation_gdl.sample_table.phenotype) 53 | for i in range(viprs_grid_model.n_models)]) 54 | prs_r2[~models_converged] = -np.inf 55 | viprs_grid_model.validation_result['Validation_R2'] = prs_r2 56 | best_model_idx = np.argmax(prs_r2) 57 | elif criterion == 'pseudo_validation': 58 | 59 | pseudo_r2 = viprs_grid_model.pseudo_validate(validation_gdl) 60 | pseudo_r2[~models_converged] = -np.inf 61 | viprs_grid_model.validation_result['Pseudo_Validation_R2'] = pseudo_r2 62 | best_model_idx = np.argmax(np.nan_to_num(pseudo_r2, nan=0., neginf=0., posinf=0.)) 63 | 64 | logger.info(f"> Based on the {criterion} criterion, selected model: {best_model_idx}") 65 | logger.info("> Model details:\n") 66 | logger.info(viprs_grid_model.validation_result.iloc[best_model_idx, :]) 67 | 68 | # ----------------------------------------------------------------------- 69 | # Update the variational parameters and their dependencies to only select the best model: 70 | for param in (viprs_grid_model.pip, viprs_grid_model.post_mean_beta, viprs_grid_model.post_var_beta, 71 | viprs_grid_model.var_gamma, viprs_grid_model.var_mu, viprs_grid_model.var_tau, 72 | viprs_grid_model.eta, viprs_grid_model.zeta, viprs_grid_model.q, 73 | viprs_grid_model._log_var_tau): 74 | for c in param: 75 | param[c] = param[c][:, best_model_idx] 76 | 77 | # Update the eta diff: 78 | try: 79 | for c in viprs_grid_model.eta_diff: 80 | viprs_grid_model.eta_diff[c] = viprs_grid_model.eta_diff[c][:, best_model_idx] 81 | except IndexError: 82 | # Don't need to update this for the VIPRSGridPathwise model. 
83 | pass 84 | 85 | # Update sigma_epsilon: 86 | viprs_grid_model.sigma_epsilon = viprs_grid_model.sigma_epsilon[best_model_idx] 87 | 88 | # Update sigma_g: 89 | viprs_grid_model._sigma_g = viprs_grid_model._sigma_g[best_model_idx] 90 | 91 | # Update sigma beta: 92 | if isinstance(viprs_grid_model.tau_beta, dict): 93 | for c in viprs_grid_model.tau_beta: 94 | viprs_grid_model.tau_beta[c] = viprs_grid_model.tau_beta[c][:, best_model_idx] 95 | else: 96 | viprs_grid_model.tau_beta = viprs_grid_model.tau_beta[best_model_idx] 97 | 98 | # Update pi 99 | 100 | if isinstance(viprs_grid_model.pi, dict): 101 | for c in viprs_grid_model.pi: 102 | viprs_grid_model.pi[c] = viprs_grid_model.pi[c][:, best_model_idx] 103 | else: 104 | viprs_grid_model.pi = viprs_grid_model.pi[best_model_idx] 105 | 106 | # ----------------------------------------------------------------------- 107 | 108 | # Set the number of models to 1: 109 | viprs_grid_model.n_models = 1 110 | 111 | # Update the fixed parameters of the model: 112 | viprs_grid_model.set_fixed_params( 113 | viprs_grid_model.grid_table.iloc[best_model_idx].to_dict() 114 | ) 115 | 116 | # ----------------------------------------------------------------------- 117 | 118 | return viprs_grid_model 119 | 120 | 121 | def bayesian_model_average(viprs_grid_model, normalization='softmax'): 122 | """ 123 | Use Bayesian model averaging (BMA) to obtain a weighing scheme for the 124 | variational parameters of a grid of VIPRS models. The parameters of each model in the grid 125 | are assigned weights proportional to their final ELBO. 126 | 127 | :param viprs_grid_model: An instance of `VIPRSGrid` or `VIPRSGridPathwise` containing the fitted grid 128 | of VIPRS models. 129 | :param normalization: The normalization scheme for the final ELBOs. 130 | Options are (`softmax`, `sum`). 131 | :raises KeyError: If the normalization scheme is not recognized. 132 | """ 133 | 134 | if viprs_grid_model.n_models < 2: 135 | return viprs_grid_model 136 | 137 | if np.sum(viprs_grid_model.valid_terminated_models) < 1: 138 | raise ValueError("No models converged successfully. " 139 | "Cannot average models.") 140 | 141 | # Extract the models that converged successfully: 142 | models_to_keep = np.where(viprs_grid_model.valid_terminated_models)[0] 143 | 144 | elbos = viprs_grid_model.elbo() 145 | 146 | if normalization == 'softmax': 147 | from scipy.special import softmax 148 | weights = np.array(softmax(elbos)) 149 | elif normalization == 'sum': 150 | weights = np.array(elbos) 151 | 152 | # Correction for negative ELBOs: 153 | weights = weights - weights.min() + 1. 154 | weights /= weights.sum() 155 | else: 156 | raise KeyError("Normalization scheme not recognized. " 157 | "Valid options are: `softmax`, `sum`. 
" 158 | "Got: {}".format(normalization)) 159 | 160 | logger.info("Averaging PRS models with weights:", weights) 161 | 162 | # Average the model parameters: 163 | for param in (viprs_grid_model.var_gamma, viprs_grid_model.var_mu, viprs_grid_model.var_tau, 164 | viprs_grid_model.q): 165 | for c in param: 166 | param[c] = (param[c][:, models_to_keep] * weights).sum(axis=1) 167 | 168 | viprs_grid_model.eta = viprs_grid_model.compute_eta() 169 | viprs_grid_model.zeta = viprs_grid_model.compute_zeta() 170 | 171 | # Update posterior moments: 172 | viprs_grid_model.update_posterior_moments() 173 | 174 | # Update the log of the variational tau parameters: 175 | viprs_grid_model._log_var_tau = {c: np.log(viprs_grid_model.var_tau[c]) 176 | for c in viprs_grid_model.var_tau} 177 | 178 | # Update the hyperparameters based on the averaged weights 179 | import copy 180 | # TODO: double check to make sure this makes sense. 181 | fix_params_before = copy.deepcopy(viprs_grid_model.fix_params) 182 | viprs_grid_model.fix_params = {} 183 | viprs_grid_model.m_step() 184 | viprs_grid_model.fix_params = fix_params_before 185 | 186 | # ----------------------------------------------------------------------- 187 | 188 | # Set the number of models to 1: 189 | viprs_grid_model.n_models = 1 190 | 191 | # ----------------------------------------------------------------------- 192 | 193 | return viprs_grid_model 194 | -------------------------------------------------------------------------------- /bin/viprs_evaluate: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Evaluate the predictive performance of PRS models 5 | ---------------------------- 6 | 7 | This is a commandline script that can compute various metrics to evaluate the predictive performance of 8 | polygenic risk score (PRS) models. The script can compute metrics for both continuous and binary phenotypes. 9 | 10 | The script requires two input files: 11 | 12 | - `--prs-file`: The path to the PRS file. The file should have the following 13 | format: FID IID PRS, where FID and IID are the family and individual IDs, and PRS is the polygenic risk score. 14 | - `--phenotype-file`: The path to the phenotype file. The file should have the following format: FID IID phenotype, 15 | where FID and IID are the family and individual IDs, and phenotype is the phenotype value. 16 | 17 | Usage: 18 | 19 | python -m viprs_evaluate --prs-file /path/to/prs_file 20 | --phenotype-file /path/to/phenotype_file 21 | --output-file /path/to/output_file 22 | 23 | """ 24 | 25 | # Setup the logger: 26 | import logging 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | def main(): 31 | 32 | import argparse 33 | import viprs as vp 34 | 35 | print("\n" + vp.make_ascii_logo( 36 | desc='< Evaluate Prediction Accuracy of PRS Models >', 37 | left_padding=10 38 | ) + "\n", flush=True) 39 | 40 | parser = argparse.ArgumentParser(description=""" 41 | Commandline arguments for evaluating polygenic scores 42 | """) 43 | 44 | parser.add_argument('--prs-file', dest='prs_file', type=str, required=True, 45 | help='The path to the PRS file (expected format: FID IID PRS, tab-separated)') 46 | parser.add_argument('--phenotype-file', dest='pheno_file', type=str, required=True, 47 | help='The path to the phenotype file. 
' 48 | 'The expected format is: FID IID phenotype (no header), tab-separated.') 49 | parser.add_argument('--phenotype-col', dest='pheno_col', type=int, default=2, 50 | help='The column index for the phenotype in the phenotype file (0-based index).') 51 | parser.add_argument('--phenotype-likelihood', dest='pheno_lik', type=str, default='infer', 52 | choices={'gaussian', 'binomial', 'infer'}, 53 | help='The phenotype likelihood ("gaussian" for continuous, "binomial" for case-control). ' 54 | 'If not set, will be inferred automatically based on the phenotype file.') 55 | parser.add_argument('--keep', dest='keep', type=str, 56 | help='A plink-style keep file to select a subset of individuals for the evaluation.') 57 | parser.add_argument('--output-file', dest='output_file', type=str, required=True, 58 | help='The output file where to store the evaluation metrics (with no extension).') 59 | parser.add_argument('--metrics', dest='metrics', type=str, nargs='+', 60 | help='The evaluation metrics to compute (default: all available metrics that are ' 61 | 'relevant for the phenotype). For a full list of supported metrics, ' 62 | 'check the documentation.') 63 | parser.add_argument('--covariates-file', dest='covariates_file', type=str, 64 | help='A file with covariates for the samples included in the analysis. This tab-separated ' 65 | 'file should not have a header and the first two columns should be ' 66 | 'the FID and IID of the samples.') 67 | parser.add_argument('--log-level', dest='log_level', type=str, default='WARNING', 68 | choices={'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'}, 69 | help='The logging level for the console output.') 70 | 71 | args = parser.parse_args() 72 | 73 | # ---------------------------------------------------------- 74 | import os.path as osp 75 | import pandas as pd 76 | from magenpy.utils.system_utils import makedir, setup_logger 77 | from magenpy import SampleTable 78 | from viprs.eval import eval_metric_names, eval_incremental_metrics 79 | from viprs.eval.eval_utils import r2_stats 80 | 81 | # ---------------------------------------------------------- 82 | # Setup the logger: 83 | 84 | # Create the output directory: 85 | makedir(osp.dirname(args.output_file)) 86 | 87 | # Clear the log file: 88 | log_file = f"{args.output_file}.log" 89 | open(log_file, 'w').close() 90 | 91 | # Set up the module loggers: 92 | setup_logger(modules=['viprs', 'magenpy'], 93 | log_file=log_file, 94 | log_level=args.log_level) 95 | 96 | # Set up the logger for the main module: 97 | setup_logger(loggers=[logger], 98 | log_file=log_file, 99 | log_format='%(message)s', 100 | log_level=['INFO', args.log_level][logging.getLevelName(args.log_level) < logging.INFO]) 101 | 102 | # ---------------------------------------------------------- 103 | logger.info('{:-^100}\n'.format(' Parsed arguments ')) 104 | 105 | for key, val in vars(args).items(): 106 | if val is not None and val != parser.get_default(key): 107 | logger.info(f"-- {key}: {val}") 108 | 109 | # ---------------------------------------------------------- 110 | logger.info('\n{:-^100}\n'.format(' Reading input data ')) 111 | 112 | sample_table = SampleTable(phenotype_likelihood=args.pheno_lik) 113 | 114 | # Read the phenotype file: 115 | sample_table.read_phenotype_file(args.pheno_file, usecols=[0, 1, args.pheno_col]) 116 | 117 | assert sample_table.n > 0, logger.error("No samples found in the phenotype file.") 118 | 119 | # Read the covariates file: 120 | if args.covariates_file is not None: 121 | 
sample_table.read_covariates_file(args.covariates_file) 122 | 123 | if args.keep is not None: 124 | sample_table.filter_samples(keep_file=args.keep) 125 | 126 | # Make sure that samples remain after reading both: 127 | assert sample_table.n > 0, logger.error("No samples found after merging the covariates and phenotype files.") 128 | 129 | prs_df = pd.read_csv(args.prs_file, sep=r'\s+') 130 | 131 | # Merge the PRS data with the phenotype data: 132 | prs_df = sample_table.get_individual_table().merge(prs_df, on=['FID', 'IID']) 133 | 134 | assert len(prs_df) > 0, logger.error("No common samples found in the PRS and phenotype files.") 135 | 136 | sample_table.filter_samples(keep_samples=prs_df.IID.values) 137 | 138 | # ---------------------------------------------------------- 139 | logger.info('\n{:-^100}\n'.format(' Evaluating PRS performance ')) 140 | 141 | if sample_table.phenotype_likelihood == 'binomial': 142 | metrics = args.metrics or ['AUROC', 'AUPRC', 'Nagelkerke_R2', 'Liability_R2'] 143 | else: 144 | metrics = args.metrics or ['Pearson_R', 'R2', 145 | 'Incremental_R2', 'Partial_Correlation'] 146 | 147 | if isinstance(metrics, str): 148 | metrics = [metrics] 149 | 150 | # Loop over the requested metrics and evaluate them, and store result in a dictionary: 151 | 152 | info_dict = {'Sample size': sample_table.n} 153 | 154 | if args.covariates_file is not None: 155 | covariates = sample_table.get_covariates_table().drop(columns=['FID', 'IID']) 156 | else: 157 | covariates = None 158 | 159 | for metric in metrics: 160 | 161 | # If covariates are provided and the metric can be computed 162 | # while adjusting for covariates, then do so: 163 | if metric in eval_incremental_metrics and covariates is not None: 164 | info_dict[metric] = eval_metric_names[metric](sample_table.phenotype, 165 | prs_df['PRS'].values, 166 | covariates) 167 | 168 | else: 169 | info_dict[metric] = eval_metric_names[metric](sample_table.phenotype, 170 | prs_df['PRS'].values) 171 | 172 | # Compute the standard errors for R-squared metrics: 173 | if 'R2' in metric: 174 | info_dict[f'{metric}_err'] = r2_stats(info_dict[metric], sample_table.n)['SE'] 175 | 176 | # ---------------------------------------------------------- 177 | 178 | logger.info(f"\n>>> Writing the evaluation metrics to:\n {osp.dirname(args.output_file)}") 179 | 180 | makedir(osp.dirname(args.output_file)) 181 | pd.DataFrame([info_dict]).to_csv(args.output_file + ".eval", sep="\t", index=False) 182 | 183 | 184 | if __name__ == '__main__': 185 | main() 186 | -------------------------------------------------------------------------------- /tests/test_basic.py: -------------------------------------------------------------------------------- 1 | 2 | import magenpy as mgp 3 | import viprs as vp 4 | from viprs.model.vi.e_step_cpp import check_blas_support, check_omp_support 5 | import numpy as np 6 | from viprs.model import VIPRS, VIPRSMix, VIPRSGrid 7 | from viprs.model.gridsearch import ( 8 | HyperparameterGrid, 9 | select_best_model, 10 | bayesian_model_average 11 | ) 12 | from functools import partial 13 | import shutil 14 | import pytest 15 | 16 | 17 | @pytest.fixture(scope='module') 18 | def gdl_object(): 19 | """ 20 | Initialize a GWADataLoader using data pre-packaged with magenpy. 21 | Make this data loader available to all tests. 
22 | """ 23 | gdl = mgp.GWADataLoader(mgp.tgp_eur_data_path(), 24 | sumstats_files=mgp.ukb_height_sumstats_path(), 25 | sumstats_format='fastgwa', 26 | backend='xarray') 27 | 28 | ld_block_url = ("https://bitbucket.org/nygcresearch/ldetect-data/raw/" 29 | "ac125e47bf7ff3e90be31f278a7b6a61daaba0dc/EUR/fourier_ls-all.bed") 30 | gdl.compute_ld('block', gdl.output_dir, ld_blocks_file=ld_block_url) 31 | 32 | gdl.harmonize_data() 33 | 34 | yield gdl 35 | 36 | # Clean up after tests are done: 37 | gdl.cleanup() 38 | shutil.rmtree(gdl.temp_dir) 39 | shutil.rmtree(gdl.output_dir) 40 | 41 | 42 | @pytest.fixture(scope='module') 43 | def viprs_model(gdl_object): 44 | """ 45 | Initialize a basic VIPRS model using GWAS sumstats data pre-packaged with magenpy. 46 | Make this data loader available to all tests. 47 | """ 48 | return vp.VIPRS(gdl_object) 49 | 50 | 51 | @pytest.fixture(scope='module') 52 | def viprsmix_model(gdl_object): 53 | """ 54 | Initialize a VIPRS model (Mixture prior) using GWAS sumstats data pre-packaged with magenpy. 55 | Make this data loader available to all tests. 56 | """ 57 | return VIPRSMix(gdl_object, K=10) 58 | 59 | 60 | @pytest.fixture(scope='module') 61 | def grid_obj(): 62 | """ 63 | Initialize a grid object. 64 | """ 65 | grid = HyperparameterGrid() 66 | grid.generate_pi_grid(steps=10) 67 | return grid 68 | 69 | 70 | @pytest.fixture(scope='module') 71 | def viprs_grid_model(gdl_object, grid_obj): 72 | """ 73 | Initialize a VIPRSGrid model using GWAS sumstats data pre-packaged with magenpy, 74 | as well as a grid object. 75 | """ 76 | 77 | return VIPRSGrid(gdl_object, grid_obj) 78 | 79 | 80 | class TestVIPRS(object): 81 | 82 | def test_init(self, 83 | viprs_model: VIPRS, 84 | gdl_object: mgp.GWADataLoader): 85 | 86 | assert viprs_model.m == gdl_object.m 87 | 88 | viprs_model.initialize() 89 | 90 | # Check the input data: 91 | for p in (viprs_model.std_beta, viprs_model.n_per_snp): 92 | assert p[22].shape == (viprs_model.m, ) 93 | 94 | # Check the LD data: 95 | assert viprs_model.ld_indptr[22].shape == (viprs_model.m + 1, ) 96 | assert viprs_model.ld_left_bound[22].shape == (viprs_model.m, ) 97 | assert viprs_model.ld_data[22].shape == (viprs_model.ld_indptr[22][-1], ) 98 | 99 | # Check hyperparameters: 100 | assert 0. < viprs_model.pi < 1. 101 | assert 0. < viprs_model.sigma_epsilon < 1. 102 | assert viprs_model.tau_beta > 0. 103 | 104 | # Check the model parameters: 105 | for p in (viprs_model.var_gamma, viprs_model.var_mu, viprs_model.var_tau, 106 | viprs_model.q, viprs_model.eta): 107 | assert p[22].shape == (viprs_model.m, ) 108 | 109 | # Other checks here? 
110 | 111 | def test_fit(self, viprs_model: VIPRS): 112 | 113 | viprs_model.fit(max_iter=10) 114 | 115 | # Check the posterior moments: 116 | for p in (viprs_model.pip, viprs_model.post_mean_beta, viprs_model.post_var_beta): 117 | assert p[22].shape == (viprs_model.m,) 118 | 119 | # Test that the following methods are working properly: 120 | viprs_model.to_table() 121 | viprs_model.to_theta_table() 122 | viprs_model.to_history_table() 123 | viprs_model.mse() 124 | viprs_model.log_prior() 125 | viprs_model.loglikelihood() 126 | viprs_model.entropy() 127 | 128 | 129 | class TestVIPRSMix(TestVIPRS): 130 | 131 | def test_init(self, 132 | viprsmix_model: VIPRSMix, 133 | gdl_object: mgp.GWADataLoader): 134 | 135 | assert viprsmix_model.m == gdl_object.m 136 | 137 | viprsmix_model.initialize() 138 | 139 | # Check the input data: 140 | assert viprsmix_model.std_beta[22].shape == (viprsmix_model.m,) 141 | assert viprsmix_model.n_per_snp[22].shape == (viprsmix_model.m, 1) 142 | 143 | # Check the LD data: 144 | assert viprsmix_model.ld_indptr[22].shape == (viprsmix_model.m + 1,) 145 | assert viprsmix_model.ld_left_bound[22].shape == (viprsmix_model.m,) 146 | assert viprsmix_model.ld_data[22].shape == (viprsmix_model.ld_indptr[22][-1],) 147 | 148 | # Check the hyperparameters: 149 | assert np.all((0. < viprsmix_model.pi) & (viprsmix_model.pi < 1.)) 150 | assert 0. < np.sum(viprsmix_model.pi) < 1. 151 | assert 0. < viprsmix_model.sigma_epsilon < 1. 152 | assert np.all(viprsmix_model.tau_beta > 0.) 153 | 154 | # Check the variational parameters: 155 | for p in (viprsmix_model.var_gamma, viprsmix_model.var_mu, viprsmix_model.var_tau): 156 | assert p[22].shape == viprsmix_model.shapes[22] 157 | 158 | # Check the aggregation parameters: 159 | for p in (viprsmix_model.q, viprsmix_model.eta): 160 | assert p[22].shape == (viprsmix_model.m, ) 161 | 162 | def test_fit(self, viprsmix_model): 163 | 164 | viprsmix_model.fit(max_iter=10) 165 | 166 | for p in (viprsmix_model.var_gamma, viprsmix_model.var_mu, viprsmix_model.var_tau): 167 | assert p[22].shape == viprsmix_model.shapes[22] 168 | 169 | for p in (viprsmix_model.pip, viprsmix_model.post_mean_beta, viprsmix_model.post_var_beta): 170 | assert p[22].shape == (viprsmix_model.m,) 171 | 172 | # Test that the following methods don't fail: 173 | viprsmix_model.to_table() 174 | viprsmix_model.to_theta_table() 175 | viprsmix_model.to_history_table() 176 | viprsmix_model.mse() 177 | viprsmix_model.log_prior() 178 | viprsmix_model.loglikelihood() 179 | viprsmix_model.entropy() 180 | 181 | 182 | class TestVIPRSGrid(TestVIPRS): 183 | 184 | """ 185 | Not testing the initialization because it should be the same as 186 | the standard VIPRS. 
187 | """ 188 | 189 | def test_fit(self, 190 | viprs_grid_model: VIPRSGrid): 191 | 192 | # Test splitting the sumstats data (PUMAS): 193 | viprs_grid_model.split_gwas_sumstats() 194 | 195 | # Check the split sumstats: 196 | for p in (viprs_grid_model.std_beta, viprs_grid_model.validation_std_beta, viprs_grid_model.n_per_snp): 197 | assert p[22].shape == (viprs_grid_model.m,) 198 | 199 | # ----------------------------------------------------------------------- 200 | 201 | model_selection_criteria = [ 202 | partial(select_best_model, criterion='ELBO'), 203 | partial(select_best_model, criterion='pseudo_validation'), 204 | bayesian_model_average 205 | ] 206 | 207 | for criterion in model_selection_criteria: 208 | # Reset the search: 209 | viprs_grid_model._reset_search() 210 | 211 | # Perform model fit: 212 | viprs_grid_model.fit(max_iter=10) 213 | 214 | for p in (viprs_grid_model.pip, viprs_grid_model.post_mean_beta, viprs_grid_model.post_var_beta): 215 | assert p[22].shape == (viprs_grid_model.m, viprs_grid_model.n_models) 216 | 217 | # Test that the following methods don't fail: 218 | viprs_grid_model.to_table() 219 | viprs_grid_model.to_theta_table() 220 | viprs_grid_model.to_history_table() 221 | viprs_grid_model.mse() 222 | viprs_grid_model.log_prior() 223 | viprs_grid_model.loglikelihood() 224 | viprs_grid_model.entropy() 225 | viprs_grid_model.pseudo_validate() 226 | 227 | # Perform model selection: 228 | criterion(viprs_grid_model) 229 | 230 | viprs_grid_model.fit(max_iter=10) 231 | 232 | # Check that the other methods work fine still: 233 | for p in (viprs_grid_model.pip, viprs_grid_model.post_mean_beta, viprs_grid_model.post_var_beta): 234 | assert p[22].shape == (viprs_grid_model.m,) 235 | 236 | viprs_grid_model.to_table() 237 | viprs_grid_model.to_theta_table() 238 | viprs_grid_model.to_history_table() 239 | viprs_grid_model.mse() 240 | viprs_grid_model.log_prior() 241 | viprs_grid_model.loglikelihood() 242 | viprs_grid_model.entropy() 243 | viprs_grid_model.pseudo_validate() 244 | 245 | 246 | @pytest.mark.xfail(not check_blas_support(), reason="BLAS library not found!") 247 | def test_check_blas_support(): 248 | assert check_blas_support() 249 | 250 | 251 | @pytest.mark.xfail(not check_omp_support(), reason="OpenMP library not found!") 252 | def test_check_omp_support(): 253 | assert check_omp_support() 254 | -------------------------------------------------------------------------------- /docs/commandline/viprs_fit.md: -------------------------------------------------------------------------------- 1 | Fit VIPRS model to GWAS summary statistics (`viprs_fit`) 2 | --- 3 | 4 | The `viprs_fit` script is used to fit the variational PRS model to the GWAS summary statistics and to estimate the 5 | posterior distribution of the variant effect sizes. The script provides a variety of options for the user to 6 | customize the inference process, including the choice of prior distributions and the choice of 7 | optimization algorithms. 
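For orientation, a minimal invocation of `viprs_fit` might look like the sketch below. The input paths are placeholders (substitute your own LD matrices, GWAS summary statistics, and output directory), and `fastgwa` is just one of the supported summary statistics formats listed further down:

```bash
viprs_fit -l data/ld/chr_* \
          -s data/sumstats/height_chr_* \
          --sumstats-format fastgwa \
          --output-dir results/height_viprs/
```

All of the flags used here (`--ld-panel`, `--sumstats`, `--sumstats-format`, `--output-dir`) are documented in the full listing below.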
8 | 9 | A full listing of the options available for the `viprs_fit` script can be found by running the following command in your terminal: 10 | 11 | ```bash 12 | viprs_fit -h 13 | ``` 14 | 15 | Which outputs the following help message: 16 | 17 | ```bash 18 | 19 |           ********************************************** 20 |                    _____                           21 |            ___   _____(_)________ ________________ 22 |            __ | / /__  / ___  __ \__  ___/__  ___/ 23 |            __ |/ / _  /  __  /_/ /_  /    _(__  )  24 |            _____/  /_/   _  .___/ /_/     /____/   25 |                          /_/                       26 |                 27 |           Variational Inference of Polygenic Risk Scores 28 |            Version: 0.1.3 | Release date: April 2025 29 |            Author: Shadi Zabad, McGill University 30 |           ********************************************** 31 |            < Fit VIPRS to GWAS summary statistics > 32 | 33 | usage: viprs_fit [-h] -l LD_DIR -s SUMSTATS_PATH --output-dir OUTPUT_DIR [--output-file-prefix OUTPUT_PREFIX] [--temp-dir TEMP_DIR] 34 | [--sumstats-format {gwas-ssf,saige,plink,custom,fastgwa,plink2,cojo,ssf,magenpy,gwascatalog,plink1.9}] 35 | [--custom-sumstats-mapper CUSTOM_SUMSTATS_MAPPER] [--custom-sumstats-sep CUSTOM_SUMSTATS_SEP] [--gwas-sample-size GWAS_SAMPLE_SIZE] 36 | [--validation-bfile VALIDATION_BED] [--validation-pheno VALIDATION_PHENO] [--validation-keep VALIDATION_KEEP] 37 | [--validation-ld-panel VALIDATION_LD_PANEL] [--validation-sumstats VALIDATION_SUMSTATS_PATH] 38 | [--validation-sumstats-format {gwas-ssf,saige,plink,custom,fastgwa,plink2,cojo,ssf,magenpy,gwascatalog,plink1.9}] [-m {VIPRS,VIPRSMix}] 39 | [--float-precision {float32,float64}] [--use-symmetric-ld] [--dequantize-on-the-fly] [--fix-sigma-epsilon FIX_SIGMA_EPSILON] 40 | [--lambda-min LAMBDA_MIN] [--n-components N_COMPONENTS] [--max-iter MAX_ITER] [--h2-est H2_EST] [--h2-se H2_SE] [--hyp-search {GS,BMA,EM}] 41 | [--grid-metric {ELBO,validation,pseudo_validation}] [--grid-search-mode {pathwise,independent}] [--prop-train PROP_TRAIN] [--pi-grid PI_GRID] 42 | [--pi-steps PI_STEPS] [--sigma-epsilon-grid SIGMA_EPSILON_GRID] [--sigma-epsilon-steps SIGMA_EPSILON_STEPS] 43 | [--lambda-min-steps LAMBDA_MIN_STEPS] [--genomewide] [--exclude-lrld] [--backend {plink,xarray}] [--n-jobs N_JOBS] [--threads THREADS] 44 | [--output-profiler-metrics] [--log-level {DEBUG,WARNING,INFO,ERROR,CRITICAL}] [--seed SEED] 45 | 46 | Commandline arguments for fitting VIPRS to GWAS summary statistics 47 | 48 | options: 49 | -h, --help show this help message and exit 50 | -l LD_DIR, --ld-panel LD_DIR 51 | The path to the directory where the LD matrices are stored. Can be a wildcard of the form ld/chr_* 52 | -s SUMSTATS_PATH, --sumstats SUMSTATS_PATH 53 | The summary statistics directory or file. Can be a wildcard of the form sumstats/chr_* 54 | --output-dir OUTPUT_DIR 55 | The output directory where to store the inference results. 56 | --output-file-prefix OUTPUT_PREFIX 57 | A prefix to append to the names of the output files (optional). 58 | --temp-dir TEMP_DIR The temporary directory where to store intermediate files. 59 | --sumstats-format {gwas-ssf,saige,plink,custom,fastgwa,plink2,cojo,ssf,magenpy,gwascatalog,plink1.9} 60 | The format for the summary statistics file(s). 
61 | --custom-sumstats-mapper CUSTOM_SUMSTATS_MAPPER 62 | A comma-separated string with column name mappings between the custom summary statistics format and the standard format expected by 63 | magenpy/VIPRS. Provide only mappings for column names that are different, in the form of:--custom-sumstats-mapper 64 | rsid=SNP,eff_allele=A1,beta=BETA 65 | --custom-sumstats-sep CUSTOM_SUMSTATS_SEP 66 | The delimiter for the summary statistics file with custom format. 67 | --gwas-sample-size GWAS_SAMPLE_SIZE 68 | The overall sample size for the GWAS study. This must be provided if the sample size per-SNP is not in the summary statistics file. 69 | --validation-bfile VALIDATION_BED 70 | The BED files containing the genotype data for the validation set. You may use a wildcard here (e.g. "data/chr_*.bed") 71 | --validation-pheno VALIDATION_PHENO 72 | A tab-separated file containing the phenotype for the validation set. The expected format is: FID IID phenotype (no header) 73 | --validation-keep VALIDATION_KEEP 74 | A plink-style keep file to select a subset of individuals for the validation set. 75 | --validation-ld-panel VALIDATION_LD_PANEL 76 | The path to the directory where the LD matrices for the validation set are stored. Can be a wildcard of the form ld/chr_* 77 | --validation-sumstats VALIDATION_SUMSTATS_PATH 78 | The summary statistics directory or file for the validation set. Can be a wildcard of the form sumstats/chr_* 79 | --validation-sumstats-format {gwas-ssf,saige,plink,custom,fastgwa,plink2,cojo,ssf,magenpy,gwascatalog,plink1.9} 80 | The format for the summary statistics file(s) for the validation set. 81 | -m {VIPRS,VIPRSMix}, --model {VIPRS,VIPRSMix} 82 | The type of PRS model to fit to the GWAS data 83 | --float-precision {float32,float64} 84 | The float precision to use when fitting the model. 85 | --use-symmetric-ld Use the symmetric form of the LD matrix when fitting the model. 86 | --dequantize-on-the-fly 87 | Dequantize the entries of the LD matrix on-the-fly during inference. 88 | --fix-sigma-epsilon FIX_SIGMA_EPSILON 89 | Set the value of the residual variance hyperparameter, sigma_epsilon, to the provided value. 90 | --lambda-min LAMBDA_MIN 91 | Set the value of the lambda_min parameter, which acts as a regularizer for the effect sizes and compensates for noise in the LD matrix. 92 | Set to "infer" to derive this parameter from the properties of the LD matrix itself. 93 | --n-components N_COMPONENTS 94 | The number of non-null Gaussian mixture components to use with the VIPRSMix model (i.e. excluding the spike component). 95 | --max-iter MAX_ITER The maximum number of iterations to run the coordinate ascent algorithm. 96 | --h2-est H2_EST The estimated heritability of the trait. If available, this value can be used for parameter initialization or hyperparameter grid 97 | search. 98 | --h2-se H2_SE The standard error for the heritability estimate for the trait. If available, this value can be used for parameter initialization or 99 | hyperparameter grid search. 100 | --hyp-search {GS,BMA,EM} 101 | The strategy for tuning the hyperparameters of the model. Options are EM (Expectation-Maximization), GS (Grid search), and BMA (Bayesian 102 | Model Averaging). 103 | --grid-metric {ELBO,validation,pseudo_validation} 104 | The metric for selecting best performing model in grid search. 105 | --grid-search-mode {pathwise,independent} 106 | The mode for grid search. 
Pathwise mode updates the hyperparameters sequentially and in a warm-start fashion, while independent mode 107 | updates each model separately starting from same initialization. 108 | --prop-train PROP_TRAIN 109 | The proportion of the samples to use for training when performing cross validation using the PUMAS procedure. 110 | --pi-grid PI_GRID A comma-separated grid values for the hyperparameter pi (see also --pi-steps). 111 | --pi-steps PI_STEPS The number of steps for the (default) pi grid. This will create an equidistant grid between 10/M and 0.2 on a log10 scale, where M is 112 | the number of variants. 113 | --sigma-epsilon-grid SIGMA_EPSILON_GRID 114 | A comma-separated grid values for the hyperparameter sigma_epsilon (see also --sigma-epsilon-steps). 115 | --sigma-epsilon-steps SIGMA_EPSILON_STEPS 116 | The number of steps (unique values) for the sigma_epsilon grid. 117 | --lambda-min-steps LAMBDA_MIN_STEPS 118 | The number of grid steps for the lambda_min grid. Lambda_min is used to compensate for noise in the LD matrix and acts as an extra 119 | regularizer for the effect sizes. 120 | --genomewide Fit all chromosomes jointly 121 | --exclude-lrld Exclude Long Range LD (LRLD) regions during inference. These regions can cause numerical instabilities in some cases. 122 | --backend {plink,xarray} 123 | The backend software used for computations on the genotype matrix. 124 | --n-jobs N_JOBS The number of processes to launch for the hyperparameter search (default is 1, but we recommend increasing this depending on system 125 | capacity). 126 | --threads THREADS The number of threads to use in the E-Step of VIPRS. 127 | --output-profiler-metrics 128 | Output the profiler metrics that measure runtime, memory usage, etc. 129 | --log-level {DEBUG,WARNING,INFO,ERROR,CRITICAL} 130 | The logging level for the console output. 131 | --seed SEED The random seed to use for the random number generator. 132 | 133 | ``` -------------------------------------------------------------------------------- /viprs/model/gridsearch/VIPRSGrid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import copy 4 | from tqdm import tqdm 5 | from tqdm.contrib.logging import logging_redirect_tqdm 6 | 7 | from ..VIPRS import VIPRS 8 | from ...utils.exceptions import OptimizationDivergence 9 | 10 | # Set up the logger: 11 | import logging 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class VIPRSGrid(VIPRS): 16 | """ 17 | A class to fit the `VIPRS` model to data using a grid of hyperparameters. 18 | Instead of having a single set of hyperparameters, we simultaneously fit 19 | multiple models with different hyperparameters and compare their performance 20 | at the end. The models with different hyperparameters are fit serially and in 21 | a pathwise manner, meaning that fit one model at a time and use its inferred parameters 22 | to initialize the next model. 23 | 24 | The class inherits all the basic attributes from the [VIPRS][viprs.model.VIPRS.VIPRS] class. 25 | 26 | :ivar grid_table: A pandas table containing the hyperparameters for each model. 27 | :ivar validation_result: A pandas table summarizing the performance of each model. 28 | :ivar optim_results: A list of optimization results for each model. 29 | :ivar n_models: The number of models to fit. 30 | 31 | """ 32 | 33 | def __init__(self, 34 | gdl, 35 | grid, 36 | **kwargs): 37 | """ 38 | Initialize the `VIPRS` model with a grid of hyperparameters. 
39 | 40 | :param gdl: An instance of `GWADataLoader` 41 | :param grid: An instance of `HyperparameterGrid` 42 | :param kwargs: Additional keyword arguments to pass to the parent `VIPRS` class. 43 | """ 44 | 45 | self.grid_table = grid.to_table() 46 | 47 | # Placeholders: 48 | self.n_models = len(self.grid_table) 49 | self.validation_result = None 50 | self.optim_results = None 51 | 52 | self._reset_search() 53 | 54 | super().__init__(gdl, **kwargs) 55 | 56 | def _reset_search(self): 57 | """ 58 | Reset the grid search object. This might be useful after 59 | fitting the model and performing model selection/BMA, to start over. 60 | """ 61 | self.n_models = len(self.grid_table) 62 | assert self.n_models > 1, "Grid search requires at least 2 models." 63 | self.validation_result = None 64 | self.optim_results = [] 65 | 66 | @property 67 | def models_to_keep(self): 68 | """ 69 | :return: A boolean array indicating which models have converged successfully. 70 | """ 71 | return np.logical_or(~self.terminated_models, self.converged_models) 72 | 73 | @property 74 | def converged_models(self): 75 | """ 76 | :return: A boolean array indicating which models have converged successfully. 77 | """ 78 | return np.array([optr.success for optr in self.optim_results]) 79 | 80 | @property 81 | def terminated_models(self): 82 | """ 83 | :return: A boolean array indicating which models have terminated. 84 | """ 85 | return np.array([optr.stop_iteration for optr in self.optim_results]) 86 | 87 | @property 88 | def valid_terminated_models(self): 89 | """ 90 | :return: A boolean array indicating which models have terminated without error. 91 | """ 92 | return np.array([optr.valid_optim_result for optr in self.optim_results]) 93 | 94 | def to_validation_table(self): 95 | """ 96 | :return: The validation table summarizing the performance of each model. 97 | :raises ValueError: if the validation result is not set. 98 | """ 99 | 100 | if self.validation_result is None or len(self.validation_result) < 1: 101 | raise ValueError("Validation result is not set!") 102 | 103 | return pd.DataFrame(self.validation_result) 104 | 105 | def write_validation_result(self, v_filename, sep="\t"): 106 | """ 107 | After performing hyperparameter search, write a table 108 | that records that value of the objective for each combination 109 | of hyperparameters. 110 | :param v_filename: The filename for the validation table. 111 | :param sep: The separator for the validation table 112 | """ 113 | 114 | v_df = self.to_validation_table() 115 | v_df.to_csv(v_filename, index=False, sep=sep) 116 | 117 | def init_optim_meta(self): 118 | """ 119 | Initialize the various quantities/objects to keep track of the optimization process. 120 | This method initializes the "history" object (which keeps track of the objective + other 121 | hyperparameters requested by the user), in addition to the OptimizeResult objects. 122 | """ 123 | super().init_optim_meta() 124 | 125 | # Reset the OptimizeResult objects: 126 | self.optim_results = [] 127 | 128 | def fit(self, 129 | pathwise=True, 130 | **fit_kwargs): 131 | """ 132 | Fit the VIPRS model to the data using a grid of hyperparameters. 133 | The method fits multiple models with different hyperparameters and compares their performance 134 | at the end. By default, the models with different hyperparameters are fit serially and 135 | in a pathwise manner, meaning that fit one model at a time and use its inferred 136 | parameters to initialize the next model. 
The user can also fit the models independently by 137 | setting `pathwise=False`. 138 | 139 | :param pathwise: Whether to fit the models in a pathwise manner. Default is `True`. 140 | :param fit_kwargs: Additional keyword arguments to pass to fit method of the parent `VIPRS` class. 141 | 142 | :return: An instance of the `VIPRSGrid` class. 143 | """ 144 | 145 | if self.n_models == 1: 146 | return super().fit(**fit_kwargs) 147 | 148 | # ----------------------------------------------------------------------- 149 | # Setup the parameters that need to be tracked: 150 | 151 | var_gamma = {c: np.empty((size, self.n_models), dtype=self.float_precision) 152 | for c, size in self.shapes.items()} 153 | var_mu = {c: np.empty((size, self.n_models), dtype=self.float_precision) 154 | for c, size in self.shapes.items()} 155 | var_tau = {c: np.empty((size, self.n_models), dtype=self.float_precision) 156 | for c, size in self.shapes.items()} 157 | q = {c: np.empty((size, self.n_models), dtype=self.float_precision) 158 | for c, size in self.shapes.items()} 159 | 160 | sigma_epsilon = np.empty(self.n_models, dtype=self.float_precision) 161 | pi = np.empty(self.n_models, dtype=self.float_precision) 162 | sigma_g = np.empty(self.n_models, dtype=self.float_precision) 163 | tau_beta = np.empty(self.n_models, dtype=self.float_precision) 164 | 165 | elbos = np.empty(self.n_models, dtype=self.float_precision) 166 | 167 | # ----------------------------------------------------------------------- 168 | 169 | # Get a list of fixed hyperparameters from the grid table: 170 | params = self.grid_table.to_dict(orient='records') 171 | orig_threads = self.threads 172 | optim_results = [] 173 | history = [] 174 | 175 | # If the model is fit over a single chromosome, append this information to the 176 | # tqdm progress bar: 177 | if len(self.shapes) == 1: 178 | chrom, num_snps = list(self.shapes.items())[0] 179 | desc = f"Grid search | Chromosome {chrom} ({num_snps} variants)" 180 | else: 181 | desc = None 182 | 183 | disable_pbar = fit_kwargs.pop('disable_pbar', False) 184 | restart = not pathwise 185 | 186 | with logging_redirect_tqdm(loggers=[logger]): 187 | 188 | # Set up the progress bar for grid search: 189 | pbar = tqdm(range(self.n_models), 190 | total=self.n_models, 191 | disable=disable_pbar, 192 | desc=desc) 193 | 194 | for i in pbar: 195 | 196 | # Fix the new set of hyperparameters: 197 | self.set_fixed_params(params[i]) 198 | 199 | # Perform model fit: 200 | super().fit(continued=i > 0 and not restart, 201 | disable_pbar=True, 202 | **fit_kwargs) 203 | 204 | # Save the optimization result: 205 | optim_results.append(copy.deepcopy(self.optim_result)) 206 | # Reset the optimization result: 207 | self.optim_result.reset() 208 | self.threads = orig_threads 209 | 210 | elbos[i] = self.history['ELBO'][-1] 211 | 212 | pbar.set_postfix({'ELBO': f"{self.history['ELBO'][-1]:.4f}", 213 | 'Models Terminated': f"{i+1}/{self.n_models}"}) 214 | 215 | # Update the saved parameters: 216 | for c in self.shapes: 217 | var_gamma[c][:, i] = self.var_gamma[c] 218 | var_mu[c][:, i] = self.var_mu[c] 219 | var_tau[c][:, i] = self.var_tau[c] 220 | q[c][:, i] = self.q[c] 221 | 222 | sigma_epsilon[i] = self.sigma_epsilon 223 | pi[i] = self.pi 224 | sigma_g[i] = self._sigma_g 225 | tau_beta[i] = self.tau_beta 226 | 227 | # Update the total number of iterations: 228 | self.optim_result.nit = np.sum([optr.nit for optr in self.optim_results]) 229 | self.optim_results = optim_results 230 | 231 | # 
----------------------------------------------------------------------- 232 | # Update the object attributes: 233 | self.var_gamma = var_gamma 234 | self.var_mu = var_mu 235 | self.var_tau = var_tau 236 | self.q = q 237 | self.eta = self.compute_eta() 238 | self.zeta = self.compute_zeta() 239 | self._log_var_tau = {c: np.log(self.var_tau[c]) for c in self.var_tau} 240 | 241 | # Update posterior moments: 242 | self.update_posterior_moments() 243 | 244 | # Hyperparameters: 245 | self.sigma_epsilon = sigma_epsilon 246 | self.pi = pi 247 | self._sigma_g = sigma_g 248 | self.tau_beta = tau_beta 249 | 250 | # ----------------------------------------------------------------------- 251 | 252 | # Population the validation result: 253 | self.validation_result = self.grid_table.copy() 254 | self.validation_result['ELBO'] = elbos 255 | self.validation_result['Converged'] = self.converged_models 256 | self.validation_result['Optimization_message'] = [optr.message for optr in self.optim_results] 257 | 258 | return self 259 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, Extension, find_packages 2 | from extension_helpers import add_openmp_flags_if_available 3 | from extension_helpers._openmp_helpers import check_openmp_support 4 | import pkgconfig 5 | import numpy as np 6 | import warnings 7 | import os 8 | 9 | 10 | try: 11 | from Cython.Build import cythonize 12 | except ImportError: 13 | warnings.warn("Cython not found.") 14 | cythonize = None 15 | 16 | # ------------------------------------------------------ 17 | # Find and set BLAS-related flags and paths: 18 | 19 | 20 | def find_blas_libraries(): 21 | """ 22 | Find BLAS libraries on the system using pkg-config. 23 | This function will return the include directories (compiler flags) 24 | and the linker flags to enable building the C/C++/Cython extensions 25 | that require BLAS (or whose performance would be enhanced with BLAS). 26 | 27 | We use pkg-config (as encapsulated in the `pkgconfig` Python package) 28 | to perform this search. Note that we augment the pkg-config 29 | search path with the conda library path (if available) to 30 | enable linking against BLAS libraries installed via Conda. 31 | 32 | :return: A dictionary with the following keys: 33 | * 'found': A boolean indicating whether BLAS libraries were found. 34 | * 'include_dirs': A list of include directories (compiler flags). 35 | * 'extra_link_args': A list of linker flags. 36 | * 'define_macros': A list of macros to define. 37 | * 'libraries': A list of libraries to link against. 38 | """ 39 | 40 | # STEP 0: Get the current pkg-config search path: 41 | current_pkg_config_path = os.getenv("PKG_CONFIG_PATH", "") 42 | 43 | # STEP 1: Augment the pkg-config search path with 44 | # the path of the current Conda environment (if exists). 45 | # This can leverage BLAS libraries installed via Conda. 46 | 47 | conda_path = os.getenv("CONDA_PREFIX") 48 | 49 | if conda_path is not None: 50 | conda_pkgconfig_path = os.path.join(conda_path, 'lib/pkgconfig') 51 | if os.path.isdir(conda_pkgconfig_path): 52 | current_pkg_config_path += ":" + conda_pkgconfig_path 53 | 54 | # STEP 2: Add the updated path to the environment variable: 55 | os.environ["PKG_CONFIG_PATH"] = current_pkg_config_path 56 | 57 | # STEP 3: Get all pkg-config packages and filter to 58 | # those that have "blas" in the name. 
59 | blas_packages = [pkg for pkg in pkgconfig.list_all() 60 | if "blas" in pkg] 61 | 62 | # First check: Make sure that compiler flags are defined and a 63 | # valid cblas.h header file exists in the include directory: 64 | if len(blas_packages) >= 1: 65 | 66 | blas_packages = [pkg for pkg in blas_packages 67 | if pkgconfig.cflags(pkg) and 68 | os.path.isfile(os.path.join(pkgconfig.variables(pkg)['includedir'], 'cblas.h'))] 69 | 70 | # If there remains more than one library after the previous 71 | # search and filtering steps, then apply some heuristics 72 | # to select the most relevant one: 73 | if len(blas_packages) > 1: 74 | # Check if the information about the most relevant library 75 | # can be inferred from numpy. Note that this interface from 76 | # numpy changes quite often between versions, so it's not 77 | # a reliable check. But in case it works on some systems, 78 | # we use it to link to the same library as numpy: 79 | try: 80 | for pkg in blas_packages: 81 | if pkg in np.__config__.get_info('blas_opt')['libraries']: 82 | blas_packages = [pkg] 83 | break 84 | except (KeyError, AttributeError): 85 | pass 86 | 87 | # If there are still multiple libraries, then apply some 88 | # additional heuristics (based on name matching) to select 89 | # the most relevant one. Some libraries (e.g. flexiblas) also ship 90 | # 64-bit builds and separate packages for their non-BLAS API 91 | # (with the _api suffix), which we treat as redundant variants. 92 | if len(blas_packages) > 1: 93 | # Keep a single entry per library: drop any package whose name 94 | # is an extension of another candidate's name (e.g. drop 95 | # 'flexiblas_api' or 'flexiblas64' when 'flexiblas' is present). 96 | 97 | idx_to_remove = set() 98 | 99 | for pkg1 in blas_packages: 100 | if pkg1 != 'blas': 101 | for i, pkg2 in enumerate(blas_packages): 102 | if pkg1 != pkg2 and pkg1 in pkg2: 103 | idx_to_remove.add(i) 104 | 105 | blas_packages = [pkg for i, pkg in enumerate(blas_packages) if i not in idx_to_remove] 106 | 107 | # After applying all the heuristics, out of all the remaining libraries, 108 | # select the first one in the list. Not the greatest solution, maybe 109 | # down the line we can use the same BLAS order as numpy. 110 | if len(blas_packages) >= 1: 111 | final_blas_pkg = blas_packages[0] 112 | else: 113 | final_blas_pkg = None 114 | 115 | # STEP 4: If a relevant BLAS package was found, extract the flags 116 | # needed for building the Cython/C/C++ extensions: 117 | 118 | if final_blas_pkg is not None: 119 | blas_info = pkgconfig.parse(final_blas_pkg) 120 | blas_info['define_macros'] = [('HAVE_CBLAS', None)] 121 | else: 122 | blas_info = { 123 | 'include_dirs': [], 124 | 'library_dirs': [], 125 | 'libraries': [], 126 | 'define_macros': [], 127 | } 128 | warnings.warn(""" 129 | ********************* WARNING ********************* 130 | BLAS library header files not found on your system. 131 | This may slow down some computations. If you are 132 | using conda, we recommend installing BLAS libraries 133 | beforehand. 
134 | ********************* WARNING ********************* 135 | """, stacklevel=2) 136 | 137 | return blas_info 138 | 139 | 140 | blas_flags = find_blas_libraries() 141 | 142 | # ------------------------------------------------------ 143 | # Build cython extensions: 144 | 145 | 146 | def no_cythonize(cy_extensions, **_ignore): 147 | """ 148 | Copied from: 149 | https://cython.readthedocs.io/en/latest/src/userguide/source_files_and_compilation.html#distributing-cython-modules 150 | """ 151 | 152 | for extension in cy_extensions: 153 | sources = [] 154 | for s_file in extension.sources: 155 | path, ext = os.path.splitext(s_file) 156 | if ext in (".pyx", ".py"): 157 | if extension.language == "c++": 158 | ext = ".cpp" 159 | else: 160 | ext = ".c" 161 | s_file = path + ext 162 | sources.append(s_file) 163 | extension.sources[:] = sources 164 | 165 | return cy_extensions 166 | 167 | 168 | extensions = [ 169 | Extension("viprs.utils.math_utils", 170 | ["viprs/utils/math_utils.pyx"], 171 | libraries=[[], ["m"]][os.name != 'nt'], # Only include for non-Windows systems 172 | include_dirs=[np.get_include()], 173 | extra_compile_args=["-O3"]), 174 | Extension("viprs.model.vi.e_step_cpp", 175 | ["viprs/model/vi/e_step_cpp.pyx"], 176 | language="c++", 177 | libraries=blas_flags['libraries'], 178 | include_dirs=[np.get_include()] + blas_flags['include_dirs'], 179 | library_dirs=blas_flags['library_dirs'], 180 | define_macros=[("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")] + blas_flags['define_macros'], 181 | extra_compile_args=["-O3", "-std=c++17"]) 182 | ] 183 | 184 | if check_openmp_support(): 185 | # Add any extension that requires openMP here: 186 | openmp_extensions = ['viprs.model.vi.e_step_cpp'] 187 | 188 | for omp_ext in extensions: 189 | if omp_ext.name in openmp_extensions: 190 | add_openmp_flags_if_available(omp_ext) 191 | else: 192 | warnings.warn(""" 193 | ******************** WARNING ******************** 194 | OpenMP library not found on your system. This 195 | means that some computations may be slower than 196 | expected. It will preclude using multithreading 197 | in the coordinate ascent optimization algorithm. 
198 | ******************** WARNING ******************** 199 | """) 200 | 201 | 202 | if cythonize is not None: 203 | compiler_directives = { 204 | "language_level": 3, 205 | "embedsignature": True, 206 | 'boundscheck': False, 207 | 'wraparound': False, 208 | 'nonecheck': False, 209 | 'cdivision': True 210 | } 211 | extensions = cythonize(extensions, compiler_directives=compiler_directives) 212 | else: 213 | extensions = no_cythonize(extensions) 214 | 215 | # ------------------------------------------------------ 216 | # Read description/dependencies from file: 217 | 218 | with open("README.md", "r", encoding="utf-8") as fh: 219 | long_description = fh.read() 220 | 221 | with open("requirements.txt") as fp: 222 | install_requires = fp.read().strip().split("\n") 223 | 224 | with open("requirements-optional.txt") as fp: 225 | opt_requires = fp.read().strip().split("\n") 226 | 227 | with open("requirements-test.txt") as fp: 228 | test_requires = fp.read().strip().split("\n") 229 | 230 | with open("requirements-docs.txt") as fp: 231 | doc_requires = fp.read().strip().split("\n") 232 | 233 | # ------------------------------------------------------ 234 | 235 | setup( 236 | name="viprs", 237 | version="0.1.3", 238 | author="Shadi Zabad", 239 | author_email="shadi.zabad@mail.mcgill.ca", 240 | description="Variational Inference of Polygenic Risk Scores (VIPRS)", 241 | long_description=long_description, 242 | long_description_content_type="text/markdown", 243 | url="https://github.com/shz9/viprs", 244 | classifiers=[ 245 | 'Programming Language :: Python', 246 | 'Intended Audience :: Developers', 247 | 'Intended Audience :: Science/Research', 248 | 'Topic :: Software Development :: Libraries :: Python Modules', 249 | 'Topic :: Scientific/Engineering', 250 | 'Operating System :: OS Independent', 251 | 'Programming Language :: Python :: 3', 252 | 'Programming Language :: Python :: 3.8', 253 | 'Programming Language :: Python :: 3.9', 254 | 'Programming Language :: Python :: 3.10', 255 | 'Programming Language :: Python :: 3.11', 256 | 'Programming Language :: Python :: 3.12' 257 | ], 258 | package_dir={'': '.'}, 259 | packages=find_packages(), 260 | python_requires=">=3.8", 261 | package_data={'viprs': ['model/vi/*.pxd', 'utils/*.pxd']}, 262 | scripts=['bin/viprs_fit', 'bin/viprs_score', 'bin/viprs_evaluate'], 263 | install_requires=install_requires, 264 | extras_require={'opt': opt_requires, 'test': test_requires, 'docs': doc_requires}, 265 | ext_modules=extensions, 266 | zip_safe=False 267 | ) 268 | -------------------------------------------------------------------------------- /viprs/model/gridsearch/HyperparameterGrid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import itertools 4 | 5 | 6 | class HyperparameterGrid(object): 7 | """ 8 | A utility class to facilitate generating grids for the 9 | hyperparameters of the standard `VIPRS` models. It is designed to 10 | interface with models that operate on grids of hyperparameters, 11 | such as `VIPRSGridSearch` and `VIPRSBMA`. The hyperparameters for 12 | the standard VIPRS model are: 13 | 14 | * `sigma_epsilon`: The residual variance for the phenotype. 15 | * `tau_beta`: The precision (inverse variance) of the prior for the effect sizes. 16 | * `pi`: The proportion of non-zero effect sizes (polygenicity). 17 | * `lambda_min`: The extra ridge penalty that compensates for the non-PSD nature of the LD matrix. 
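
    A minimal usage sketch (the step counts and heritability values below are
    hypothetical placeholders):

        grid = HyperparameterGrid(sigma_epsilon_steps=5, pi_steps=5,
                                  h2_est=0.25, h2_se=0.05, n_snps=1e6)
        grid_df = grid.to_table()  # one row per hyperparameter combination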
18 | 19 | :ivar sigma_epsilon: A grid of values for the residual variance hyperparameter. 20 | :ivar tau_beta: A grid of values for the precision of the prior for the effect sizes. 21 | :ivar pi: A grid of values for the proportion of non-zero effect sizes. 22 | :ivar lambda_min: A grid of values for the extra ridge penalty that compensates for 23 | the non-PSD nature of the LD matrix. 24 | :ivar h2_est: An estimate of the heritability for the trait under consideration. 25 | :ivar h2_se: The standard error of the heritability estimate. 26 | :ivar n_snps: The number of common variants that may be relevant for this analysis. 27 | 28 | """ 29 | 30 | def __init__(self, 31 | sigma_epsilon_grid=None, 32 | sigma_epsilon_steps=None, 33 | tau_beta_grid=None, 34 | tau_beta_steps=None, 35 | pi_grid=None, 36 | pi_steps=None, 37 | lambda_min_grid=None, 38 | lambda_min_steps=None, 39 | h2_est=None, 40 | h2_se=None, 41 | n_snps=1e6): 42 | """ 43 | 44 | Create a hyperparameter grid for the standard VIPRS model with the 45 | spike-and-slab prior. The hyperparameters for this model are: 46 | 47 | * `sigma_epsilon`: The residual variance 48 | * `tau_beta`: The precision (inverse variance) of the prior for the effect sizes 49 | * `pi`: The proportion of non-zero effect sizes 50 | * `lambda_min`: The extra ridge penalty that compensates for the non-PSD LD matrices. 51 | 52 | For each of these hyperparameters, we can provide a grid of values to search over. 53 | If the heritability estimate and standard error (from e.g. LDSC) are provided, 54 | we can generate grids for sigma_epsilon and tau_beta that are informed by these estimates. 55 | 56 | For each hyperparameter to be included in the grid, user must specify either the grid 57 | itself, or the number of steps to use to generate the grid. 58 | 59 | :param sigma_epsilon_grid: An array containing a grid of values for the sigma_epsilon hyperparameter. 60 | :param sigma_epsilon_steps: The number of steps for the sigma_epsilon grid 61 | :param tau_beta_grid: An array containing a grid of values for the tau_beta hyperparameter. 62 | :param tau_beta_steps: The number of steps for the tau_beta grid 63 | :param pi_grid: An array containing a grid of values for the pi hyperparameter 64 | :param pi_steps: The number of steps for the pi grid 65 | :param h2_est: An estimate of the heritability for the trait under consideration. If provided, 66 | we can generate grids for some of the hyperparameters that are consistent with this estimate. 67 | :param h2_se: The standard error of the heritability estimate. If provided, we can generate grids 68 | for some of the hyperparameters that are consistent with this estimate. 69 | :param n_snps: Number of common variants that may be relevant for this analysis. This estimate can 70 | be used to generate grids that are based on this number. 71 | """ 72 | 73 | # If the heritability estimate is not provided, use a reasonable default value of 0.1 74 | # with a wide standard error of 0.1. 
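        # Note: the `or` idiom below also replaces an explicit value of 0 with the
        # default; the grid-generation helpers further down assert that
        # 0 < h2_est < 1 in any case.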
75 | 76 | self.h2_est = h2_est or 0.1 77 | self.h2_se = h2_se or 0.1 78 | 79 | self.n_snps = n_snps 80 | self._search_params = [] 81 | 82 | # Initialize the grid for sigma_epsilon: 83 | self.sigma_epsilon = sigma_epsilon_grid 84 | if self.sigma_epsilon is not None: 85 | self._search_params.append('sigma_epsilon') 86 | elif sigma_epsilon_steps is not None: 87 | self.generate_sigma_epsilon_grid(steps=sigma_epsilon_steps) 88 | 89 | # Initialize the grid for the tau_beta: 90 | self.tau_beta = tau_beta_grid 91 | if self.tau_beta is not None: 92 | self._search_params.append('tau_beta') 93 | elif tau_beta_steps is not None: 94 | self.generate_tau_beta_grid(steps=tau_beta_steps) 95 | 96 | # Initialize the grid for pi: 97 | self.pi = pi_grid 98 | if self.pi is not None: 99 | self._search_params.append('pi') 100 | elif pi_steps is not None: 101 | self.generate_pi_grid(steps=pi_steps) 102 | 103 | # Initialize the grid for lambda_min: 104 | self.lambda_min = lambda_min_grid 105 | if self.lambda_min is not None: 106 | self._search_params.append('lambda_min') 107 | elif lambda_min_steps is not None: 108 | self.generate_lambda_min_grid(steps=lambda_min_steps) 109 | 110 | def _generate_h2_grid(self, steps=5): 111 | """ 112 | Use the heritability estimate and standard error to generate a grid of values for 113 | the heritability parameter. Specifically, given the estimate and standard error, we 114 | generate heritability estimates from the percentiles of the normal distribution, 115 | with mean `h2_est` and standard deviation `h2_se`. The grid values range from the 10th 116 | percentile to the 90th percentile of this normal distribution. 117 | 118 | :param steps: The number of steps for the heritability grid. 119 | :return: A grid of values for the heritability parameter. 120 | 121 | """ 122 | 123 | assert steps > 0 124 | assert self.h2_est is not None 125 | 126 | # If the heritability standard error is not provided, we use half of the heritability estimate 127 | # by default. 128 | # *Justification*: Under the assumption that heritability for the trait being analyzed 129 | # is significantly greater than 0, the standard error should be, at a maximum, 130 | # half of the heritability estimate itself to get us a Z-score with absolute value 131 | # greater than 2. 132 | if self.h2_se is None: 133 | h2_se = self.h2_est * 0.5 134 | else: 135 | h2_se = self.h2_se 136 | 137 | # Sanity checking steps: 138 | assert 0. < self.h2_est < 1. 139 | assert h2_se > 0 140 | 141 | from scipy.stats import norm 142 | 143 | # First, determine the percentile boundaries to avoid producing 144 | # invalid values for the heritability grid: 145 | 146 | percentile_start = max(0.1, norm.cdf(1e-5, loc=self.h2_est, scale=h2_se)) 147 | percentile_stop = min(0.9, norm.cdf(1. - 1e-5, loc=self.h2_est, scale=h2_se)) 148 | 149 | # Generate the heritability grid: 150 | return norm.ppf(np.linspace(percentile_start, percentile_stop, steps), 151 | loc=self.h2_est, scale=h2_se) 152 | 153 | def generate_sigma_epsilon_grid(self, steps=5): 154 | """ 155 | Generate a grid of values for the `sigma_epsilon` (residual variance) hyperparameter. 156 | 157 | :param steps: The number of steps for the sigma_epsilon grid. 158 | """ 159 | 160 | assert steps > 0 161 | 162 | h2_grid = self._generate_h2_grid(steps) 163 | self.sigma_epsilon = 1. 
- h2_grid 164 | 165 | if 'sigma_epsilon' not in self._search_params: 166 | self._search_params.append('sigma_epsilon') 167 | 168 | def generate_tau_beta_grid(self, steps=5): 169 | """ 170 | Generate a grid of values for the `tau_beta` 171 | (precision of the prior for the effect sizes) hyperparameter. 172 | :param steps: The number of steps for the `tau_beta` grid 173 | """ 174 | 175 | assert steps > 0 176 | 177 | h2_grid = self._generate_h2_grid(steps) 178 | # Assume ~1% of SNPs are causal: 179 | self.tau_beta = 0.01*self.n_snps / h2_grid 180 | 181 | if 'tau_beta' not in self._search_params: 182 | self._search_params.append('tau_beta') 183 | 184 | def generate_pi_grid(self, steps=5, max_pi=0.2): 185 | """ 186 | Generate a grid of values for the `pi` (proportion of non-zero effect sizes) hyperparameter. 187 | :param steps: The number of steps for the `pi` grid 188 | :param max_pi: The maximum value for the `pi` grid. 189 | """ 190 | 191 | assert steps > 0 192 | 193 | min_pi = np.log10(max(10./self.n_snps, 1e-5)) 194 | # For now, we impose a limit of 10k causal variants. 195 | # We may need better ways to determine the maximum 196 | # value here. 197 | max_pi = np.log10(min(10000 / self.n_snps, max_pi)) 198 | 199 | assert min_pi < max_pi 200 | 201 | self.pi = np.logspace( 202 | min_pi, 203 | max_pi, 204 | steps 205 | ) 206 | 207 | if 'pi' not in self._search_params: 208 | self._search_params.append('pi') 209 | 210 | def generate_lambda_min_grid(self, steps=5, emp_lambda_min=None): 211 | """ 212 | Generate a grid of values for the `lambda_min` parameter, associated with the extra ridge penalty 213 | that compensates for the non-PSD nature of the LD matrix. 214 | :param steps: The number of steps for the `lambda_min` grid. 215 | :param emp_lambda_min: The empirical value of lambda_min to use as a reference point. 216 | """ 217 | 218 | assert steps > 0 219 | 220 | self.lambda_min = np.concatenate([[0.], np.logspace(-4, 1., steps - 1)]) 221 | 222 | if emp_lambda_min is not None: 223 | self.lambda_min *= emp_lambda_min 224 | 225 | if 'lambda_min' not in self._search_params: 226 | self._search_params.append('lambda_min') 227 | 228 | def combine_grids(self): 229 | """ 230 | Weave together the different hyperparameter grids and return a list of 231 | dictionaries where the keys are the hyperparameter names and the values 232 | are the corresponding grid values. 233 | 234 | :return: A list of dictionaries containing the hyperparameter values. 235 | :raises ValueError: If all the grids are empty. 236 | 237 | """ 238 | hyp_names = [name for name, value in self.__dict__.items() 239 | if value is not None and name in self._search_params] 240 | 241 | if len(hyp_names) > 0: 242 | hyp_values = itertools.product(*[hyp_grid for hyp_name, hyp_grid in self.__dict__.items() 243 | if hyp_grid is not None and hyp_name in hyp_names]) 244 | 245 | return [dict(zip(hyp_names, hyp_v)) for hyp_v in hyp_values] 246 | else: 247 | raise ValueError("All the grids are empty!") 248 | 249 | def to_table(self): 250 | """ 251 | :return: The hyperparameter grid as a pandas `DataFrame`. 252 | :raises ValueError: If all the grids are empty. 
253 | """ 254 | 255 | combined_grids = self.combine_grids() 256 | 257 | return pd.DataFrame(combined_grids) 258 | -------------------------------------------------------------------------------- /viprs/eval/binary_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .continuous_metrics import incremental_r2 3 | from .eval_utils import fit_linear_model 4 | import pandas as pd 5 | 6 | 7 | def roc_auc(true_val, pred_val): 8 | """ 9 | Compute the area under the ROC curve (AUROC) for a model 10 | that maps from the PRS predictions to the binary phenotype. 11 | 12 | :param true_val: The response value or phenotype (a binary numpy vector with 0s and 1s) 13 | :param pred_val: The predicted value or PRS (a numpy vector) 14 | """ 15 | from sklearn.metrics import roc_auc_score 16 | return roc_auc_score(true_val, pred_val) 17 | 18 | 19 | def pr_auc(true_val, pred_val): 20 | """ 21 | Compute the area under the Precision-Recall curve for a model 22 | that maps from the PRS predictions to the binary phenotype. 23 | 24 | :param true_val: The response value or phenotype (a binary numpy vector with 0s and 1s) 25 | :param pred_val: The predicted value or PRS (a numpy vector) 26 | """ 27 | from sklearn.metrics import precision_recall_curve, auc 28 | precision, recall, thresholds = precision_recall_curve(true_val, pred_val) 29 | return auc(recall, precision) 30 | 31 | 32 | def avg_precision(true_val, pred_val): 33 | """ 34 | Compute the average precision between the PRS predictions and a binary phenotype. 35 | 36 | :param true_val: The response value or phenotype (a binary numpy vector with 0s and 1s) 37 | :param pred_val: The predicted value or PRS (a numpy vector) 38 | """ 39 | from sklearn.metrics import average_precision_score 40 | return average_precision_score(true_val, pred_val) 41 | 42 | 43 | def f1(true_val, pred_val): 44 | """ 45 | Compute the F1 score between the PRS predictions and a binary phenotype. 46 | 47 | :param true_val: The response value or phenotype (a binary numpy vector with 0s and 1s) 48 | :param pred_val: The predicted value or PRS (a numpy vector) 49 | """ 50 | from sklearn.metrics import f1_score 51 | return f1_score(true_val, pred_val) 52 | 53 | 54 | def mcfadden_r2(true_val, pred_val, covariates=None): 55 | """ 56 | Compute the McFadden pseudo-R^2 between the PRS predictions and a binary phenotype. 57 | If covariates are provided, we compute the incremental pseudo-R^2 by conditioning 58 | on the covariates. 59 | 60 | 61 | :param true_val: The response value or phenotype (a binary numpy vector with 0s and 1s) 62 | :param pred_val: The predicted value or PRS (a numpy vector) 63 | :param covariates: A pandas table of covariates where the rows are ordered 64 | the same way as the predictions and response. 65 | """ 66 | 67 | if covariates is None: 68 | add_intercept = False 69 | covariates = pd.DataFrame(np.ones((true_val.shape[0], 1)), columns=['const']) 70 | else: 71 | add_intercept = True 72 | 73 | null_result = fit_linear_model(true_val, covariates, 74 | family='binomial', add_intercept=add_intercept) 75 | full_result = fit_linear_model(true_val, covariates.assign(pred_val=pred_val), 76 | family='binomial', add_intercept=add_intercept) 77 | 78 | return 1. - (full_result.llf / null_result.llf) 79 | 80 | 81 | def cox_snell_r2(true_val, pred_val, covariates=None): 82 | """ 83 | Compute the Cox-Snell pseudo-R^2 between the PRS predictions and a binary phenotype. 
84 | If covariates are provided, we compute the incremental pseudo-R^2 by conditioning 85 | on the covariates. 86 | 87 | :param true_val: The response value or phenotype (a binary numpy vector with 0s and 1s) 88 | :param pred_val: The predicted value or PRS (a numpy vector) 89 | :param covariates: A pandas table of covariates where the rows are ordered 90 | the same way as the predictions and response. 91 | """ 92 | 93 | if covariates is None: 94 | add_intercept = False 95 | covariates = pd.DataFrame(np.ones((true_val.shape[0], 1)), columns=['const']) 96 | else: 97 | add_intercept = True 98 | 99 | null_result = fit_linear_model(true_val, covariates, 100 | family='binomial', add_intercept=add_intercept) 101 | full_result = fit_linear_model(true_val, covariates.assign(pred_val=pred_val), 102 | family='binomial', add_intercept=add_intercept) 103 | n = true_val.shape[0] 104 | 105 | return 1. - np.exp(-2 * (full_result.llf - null_result.llf) / n) 106 | 107 | 108 | def nagelkerke_r2(true_val, pred_val, covariates=None): 109 | """ 110 | Compute the Nagelkerke pseudo-R^2 between the PRS predictions and a binary phenotype. 111 | If covariates are provided, we compute the incremental pseudo-R^2 by conditioning 112 | on the covariates. 113 | 114 | :param true_val: The response value or phenotype (a binary numpy vector with 0s and 1s) 115 | :param pred_val: The predicted value or PRS (a numpy vector) 116 | :param covariates: A pandas table of covariates where the rows are ordered 117 | the same way as the predictions and response. 118 | """ 119 | 120 | if covariates is None: 121 | add_intercept = False 122 | covariates = pd.DataFrame(np.ones((true_val.shape[0], 1)), columns=['const']) 123 | else: 124 | add_intercept = True 125 | 126 | null_result = fit_linear_model(true_val, covariates, 127 | family='binomial', add_intercept=add_intercept) 128 | full_result = fit_linear_model(true_val, covariates.assign(pred_val=pred_val), 129 | family='binomial', add_intercept=add_intercept) 130 | n = true_val.shape[0] 131 | 132 | # First compute the Cox & Snell R2: 133 | cox_snell = 1. - np.exp(-2 * (full_result.llf - null_result.llf) / n) 134 | 135 | # Then scale it by the maximum possible R2: 136 | return cox_snell / (1. - np.exp(2 * null_result.llf / n)) 137 | 138 | 139 | def liability_r2(true_val, pred_val, covariates=None, return_all_r2=False): 140 | """ 141 | Compute the coefficient of determination (R^2) on the liability scale 142 | according to Lee et al. (2012) Gene. Epi. 143 | https://pubmed.ncbi.nlm.nih.gov/22714935/ 144 | 145 | The R^2 on the liability scale is defined as: 146 | R2_{liability} = R2_{observed} * K * (1 - K) / z^2 147 | 148 | where R2_{observed} is the R^2 on the observed scale, K is the sample prevalence, 149 | and z is the "height of the normal density at the quantile for K". 150 | 151 | If covariates are provided, we compute the incremental pseudo-R^2 by conditioning 152 | on the covariates. 153 | 154 | :param true_val: The response value or phenotype (a binary numpy vector with 0s and 1s) 155 | :param pred_val: The predicted value or PRS (a numpy vector) 156 | :param covariates: A pandas table of covariates where the rows are ordered 157 | the same way as the predictions and response. 158 | :param return_all_r2: If True, return the null, full and incremental R2 values. 
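    As a rough worked example (with a purely hypothetical prevalence): if K = 0.1,
    then the quantile is norm.ppf(1 - K) ~ 1.28, the density height is
    z = norm.pdf(1.28) ~ 0.175, and the multiplier K*(1 - K)/z^2 ~ 2.9, so an
    observed-scale R^2 of 0.01 maps to roughly 0.029 on the liability scale.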
159 | """ 160 | 161 | # First, obtain the incremental R2 on the observed scale: 162 | r2_obs = incremental_r2(true_val, pred_val, covariates, return_all_r2=return_all_r2) 163 | 164 | # Second, compute the prevalence and the standard normal quantile of the prevalence: 165 | 166 | from scipy.stats import norm 167 | 168 | k = np.mean(true_val) 169 | z2 = norm.pdf(norm.ppf(1.-k))**2 170 | mult_factor = k*(1. - k) / z2 171 | 172 | if return_all_r2: 173 | return { 174 | 'Null_R2': r2_obs['Null_R2']*mult_factor, 175 | 'Full_R2': r2_obs['Full_R2']*mult_factor, 176 | 'Incremental_R2': r2_obs['Incremental_R2']*mult_factor 177 | } 178 | else: 179 | return r2_obs * mult_factor 180 | 181 | 182 | def liability_probit_r2(true_val, pred_val, covariates=None, return_all_r2=False): 183 | """ 184 | Compute the R^2 between the PRS predictions and a binary phenotype on the liability 185 | scale using the probit likelihood as outlined in Lee et al. (2012) Gene. Epi. 186 | https://pubmed.ncbi.nlm.nih.gov/22714935/ 187 | 188 | The R^2 is defined as: 189 | R2_{probit} = Var(pred) / (Var(pred) + 1) 190 | 191 | Where Var(pred) is the variance of the predicted liability. 192 | 193 | If covariates are provided, we compute the incremental pseudo-R^2 by conditioning 194 | on the covariates. 195 | 196 | :param true_val: The response value or phenotype (a binary numpy vector with 0s and 1s) 197 | :param pred_val: The predicted value or PRS (a numpy vector) 198 | :param covariates: A pandas table of covariates where the rows are ordered 199 | the same way as the predictions and response. 200 | :param return_all_r2: If True, return the null, full and incremental R2 values. 201 | """ 202 | 203 | if covariates is None: 204 | add_intercept = False 205 | covariates = pd.DataFrame(np.ones((true_val.shape[0], 1)), columns=['const']) 206 | else: 207 | add_intercept = True 208 | 209 | null_result = fit_linear_model(true_val, covariates, 210 | family='binomial', link='probit', add_intercept=add_intercept) 211 | full_result = fit_linear_model(true_val, covariates.assign(pred_val=pred_val), 212 | family='binomial', link='probit', add_intercept=add_intercept) 213 | 214 | null_var = np.var(null_result.predict()) 215 | null_r2 = null_var / (null_var + 1.) 216 | 217 | full_var = np.var(full_result.predict()) 218 | full_r2 = full_var / (full_var + 1.) 219 | 220 | if return_all_r2: 221 | return { 222 | 'Null_R2': null_r2, 223 | 'Full_R2': full_r2, 224 | 'Incremental_R2': full_r2 - null_r2 225 | } 226 | else: 227 | return full_r2 - null_r2 228 | 229 | 230 | def liability_logit_r2(true_val, pred_val, covariates=None, return_all_r2=False): 231 | """ 232 | Compute the R^2 between the PRS predictions and a binary phenotype on the liability 233 | scale using the logit likelihood as outlined in Lee et al. (2012) Gene. Epi. 234 | https://pubmed.ncbi.nlm.nih.gov/22714935/ 235 | 236 | The R^2 is defined as: 237 | R2_{logit} = Var(pred) / (Var(pred) + pi^2 / 3) 238 | 239 | Where Var(pred) is the variance of the predicted liability. 240 | 241 | If covariates are provided, we compute the incremental pseudo-R^2 by conditioning 242 | on the covariates. 243 | 244 | :param true_val: The response value or phenotype (a binary numpy vector with 0s and 1s) 245 | :param pred_val: The predicted value or PRS (a numpy vector) 246 | :param covariates: A pandas table of covariates where the rows are ordered 247 | the same way as the predictions and response. 248 | :param return_all_r2: If True, return the null, full and incremental R2 values. 
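    (The constant pi^2/3 ~ 3.29 is the variance of the standard logistic distribution,
    which plays the role of the residual variance on the latent liability scale.)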
249 | """ 250 | 251 | if covariates is None: 252 | add_intercept = False 253 | covariates = pd.DataFrame(np.ones((true_val.shape[0], 1)), columns=['const']) 254 | else: 255 | add_intercept = True 256 | 257 | null_result = fit_linear_model(true_val, covariates, 258 | family='binomial', add_intercept=add_intercept) 259 | full_result = fit_linear_model(true_val, covariates.assign(pred_val=pred_val), 260 | family='binomial', add_intercept=add_intercept) 261 | 262 | null_var = np.var(null_result.predict()) 263 | null_r2 = null_var / (null_var + (np.pi**2 / 3)) 264 | 265 | full_var = np.var(full_result.predict()) 266 | full_r2 = full_var / (full_var + (np.pi**2 / 3)) 267 | 268 | if return_all_r2: 269 | return { 270 | 'Null_R2': null_r2, 271 | 'Full_R2': full_r2, 272 | 'Incremental_R2': full_r2 - null_r2 273 | } 274 | else: 275 | return full_r2 - null_r2 276 | -------------------------------------------------------------------------------- /notebooks/height_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "pycharm": { 7 | "name": "#%% md\n" 8 | } 9 | }, 10 | "source": [ 11 | "# Fitting VIPRS Model on GWAS data for Standing Height from the UK Biobank" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "This example illustrates how to fit the `VIPRS` model on external GWAS summary statistics \n", 19 | "from the `fastGWA` catalog. The `fastGWA` catalog is a comprehensive GWAS resource on \n", 20 | "thousands of phenotypes from the UK Biobank. In this example, we will walk the user \n", 21 | "through 4 important steps in fitting PRS models to publicly available GWAS summary data:\n", 22 | "\n", 23 | "1. **Data pre-processing**: Download the GWAS summary statistics for height and **match** them to genotype data for European samples from the 1000G project. The genotype data is restricted to about 15,000 variants on chromosome 22 for now.\n", 24 | "\n", 25 | "2. **Compute LD matrices**: After the GWAS data is downloaded and harmonized with the genotype data, we will compute Linkage-Disequilibrium (LD) matrices that will be used in model fitting. In most applications, it suffices to use publicly available LD matrices, but this example will illustrate how to compute these matrices from genotype data.\n", 26 | "\n", 27 | "3. **Model fit**: After the data is preprocessed and we have the LD matrices computed, we will fit the `VIPRS` model to the data. This will result in a set of inferred effect sizes for each of the 15,000 variants.\n", 28 | "\n", 29 | "4. **Prediction**: After the model is fit, we will predict (sometimes called scoring or linear scoring) height for the 1000G samples. 
Unfortunately, we don't have real phenotypes for those samples, so we can't evaluate accuracy, but we can inspect the distribution of polygenic scores, etc.\n", 30 | "\n", 31 | "But first things first, let's import the needed packages to run this analysis:" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 42, 37 | "metadata": { 38 | "ExecuteTime": { 39 | "end_time": "2024-04-05T21:04:15.754890Z", 40 | "start_time": "2024-04-05T21:04:15.741558Z" 41 | } 42 | }, 43 | "source": [ 44 | "import numpy as np\n", 45 | "import magenpy as mgp\n", 46 | "import viprs as vp\n", 47 | "import matplotlib.pyplot as plt\n", 48 | "import warnings\n", 49 | "warnings.filterwarnings(\"ignore\") # ignore warnings" 50 | ], 51 | "outputs": [] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "metadata": {}, 56 | "source": [ 57 | "## 1) Data pre-processing & harmonization" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "First, let's load and harmonize the data using `magenpy`:" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 3, 70 | "metadata": { 71 | "ExecuteTime": { 72 | "end_time": "2024-04-05T16:48:56.738961Z", 73 | "start_time": "2024-04-05T16:47:45.811770Z" 74 | }, 75 | "pycharm": { 76 | "name": "#%%\n" 77 | } 78 | }, 79 | "source": [ 80 | "# GWAS summary statistics for Standing Height from fastGWA:\n", 81 | "sumstats_url = \"https://yanglab.westlake.edu.cn/data/fastgwa_data/UKB/50.v1.1.fastGWA.gz\"\n", 82 | "\n", 83 | "# Load genotype data for European samples in the 1000G project (chromosome 22):\n", 84 | "gdl = mgp.GWADataLoader(bed_files=mgp.tgp_eur_data_path(),\n", 85 | " sumstats_files=sumstats_url,\n", 86 | " sumstats_format=\"fastGWA\")" 87 | ], 88 | "outputs": [] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "## 2) Computing LD matrices:" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": { 100 | "pycharm": { 101 | "name": "#%% md\n" 102 | } 103 | }, 104 | "source": [ 105 | "Then, we use `magenpy` to compute the reference LD matrices:" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 4, 111 | "metadata": { 112 | "ExecuteTime": { 113 | "end_time": "2024-04-05T16:50:06.154841Z", 114 | "start_time": "2024-04-05T16:49:57.010351Z" 115 | }, 116 | "pycharm": { 117 | "name": "#%%\n" 118 | } 119 | }, 120 | "source": [ 121 | "# Compute LD using the shrinkage estimator (Wen and Stephens 2010):\n", 122 | "gdl.compute_ld(\"shrinkage\",\n", 123 | " output_dir=\"~/temp\", # Output directory where the LD matrix will be stored\n", 124 | " genetic_map_ne=11400, # effective population size (Ne)\n", 125 | " genetic_map_sample_size=183,\n", 126 | " threshold=1e-3)" 127 | ], 128 | "outputs": [] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "metadata": {}, 133 | "source": [ 134 | "## 3) Model fit" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "metadata": { 140 | "pycharm": { 141 | "name": "#%% md\n" 142 | } 143 | }, 144 | "source": [ 145 | "Next, we fit the `VIPRS` to the harmonized GWAS summary statistics data. 
Note that the fit will mainly be done on the variants on chromosome 22:" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 5, 151 | "metadata": { 152 | "ExecuteTime": { 153 | "end_time": "2024-04-05T16:50:16.136153Z", 154 | "start_time": "2024-04-05T16:50:11.574996Z" 155 | }, 156 | "pycharm": { 157 | "name": "#%%\n" 158 | } 159 | }, 160 | "source": [ 161 | "# Fit VIPRS to the summary statistics:\n", 162 | "v = vp.VIPRS(gdl).fit()" 163 | ], 164 | "outputs": [] 165 | }, 166 | { 167 | "cell_type": "markdown", 168 | "metadata": { 169 | "ExecuteTime": { 170 | "end_time": "2024-04-05T21:06:25.384444Z", 171 | "start_time": "2024-04-05T21:06:25.350656Z" 172 | } 173 | }, 174 | "source": [ 175 | "To verify that the model fit behaved as expected with no issues, we can inspect \n", 176 | "the objective (Evidence Lower BOund or `ELBO`) as a function of the number of iterations. `viprs` provides a \n", 177 | "convenience function to generate this plot:" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 45, 183 | "metadata": { 184 | "ExecuteTime": { 185 | "end_time": "2024-04-05T21:06:37.819682Z", 186 | "start_time": "2024-04-05T21:06:35.161637Z" 187 | } 188 | }, 189 | "source": [ 190 | "from viprs.plot.diagnostics import plot_history\n", 191 | "\n", 192 | "plot_history(v)" 193 | ], 194 | "outputs": [] 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "metadata": {}, 199 | "source": [ 200 | "Now that the model converged, we can inspect its estimates of both the global hyperparameters as well as summaries of the posterior distribution for the effect sizes of individual variants.\n", 201 | "\n", 202 | "To obtain the estimates for some of the global hyperparameters, such as heritability, residual variance, proportion of causal variants, etc., we can simply invoke the method `.to_theta_table`:" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": 46, 208 | "metadata": { 209 | "ExecuteTime": { 210 | "end_time": "2024-04-05T21:09:58.696758Z", 211 | "start_time": "2024-04-05T21:09:58.688218Z" 212 | } 213 | }, 214 | "source": [ 215 | "v.to_theta_table()" 216 | ], 217 | "outputs": [] 218 | }, 219 | { 220 | "cell_type": "markdown", 221 | "metadata": {}, 222 | "source": [ 223 | "**NOTE:** `VIPRS` is not a method used to estimate heritability or polygenicity (proportion of causal variants). However, \n", 224 | "we can obtain estimates for these quantities as part of the model fit.\n", 225 | " \n", 226 | "As for summaries of the posterior distribution, one thing we can look at is the \n", 227 | "**P**osterior **I**nclusion **P**robability (**PIP**), which is a metric that summarizes \n", 228 | "the probability that the variant of interest is causal for the phenotype of interest (e.g. 
Standing Height):" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": 36, 234 | "metadata": { 235 | "ExecuteTime": { 236 | "end_time": "2024-04-05T19:39:29.120032Z", 237 | "start_time": "2024-04-05T19:39:28.919630Z" 238 | } 239 | }, 240 | "source": [ 241 | "# Get the inferred effect sizes:\n", 242 | "effect_table = v.to_table(col_subset=('CHR', 'SNP', 'POS', 'A1', 'A2'))\n", 243 | "\n", 244 | "# Plot the PIP as a function of genomic position:\n", 245 | "\n", 246 | "plt.scatter(effect_table['POS'], effect_table['PIP'], \n", 247 | " alpha=.4, marker='.')\n", 248 | "plt.xticks([])\n", 249 | "plt.xlabel(\"Genomic Position (CHR22)\")\n", 250 | "plt.ylabel(\"PIP\")\n", 251 | "plt.title(\"Posterior Inclusion Probability for Standing Height\")" 252 | ], 253 | "outputs": [] 254 | }, 255 | { 256 | "cell_type": "markdown", 257 | "metadata": { 258 | "ExecuteTime": { 259 | "end_time": "2024-04-05T19:41:54.083251Z", 260 | "start_time": "2024-04-05T19:41:54.076668Z" 261 | } 262 | }, 263 | "source": [ 264 | "We see from this that most variants have a very small probability of meaningfully contributing to Standing Height. \n", 265 | "Another illustrative thing we can do is take the posterior mean for the effect sizes obtained by `VIPRS` \n", 266 | "and compare it to the marginal effect sizes obtained from GWAS:" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": 38, 272 | "metadata": { 273 | "ExecuteTime": { 274 | "end_time": "2024-04-05T19:40:39.852050Z", 275 | "start_time": "2024-04-05T19:40:39.623387Z" 276 | } 277 | }, 278 | "source": [ 279 | "# Get the summary statistics table:\n", 280 | "# NOTE: For the purposes of comparing effects on the same scale,\n", 281 | "# here, we get standardized BETAs, which is what VIPRS uses for inference:\n", 282 | "sumstats = gdl.to_summary_statistics_table(col_subset=('CHR', 'SNP', 'POS', 'A1', 'A2', 'STD_BETA'))\n", 283 | "\n", 284 | "# Rename the BETAs for clarity:\n", 285 | "sumstats.rename(columns={'STD_BETA': 'GWAS_BETA'}, inplace=True)\n", 286 | "effect_table.rename(columns={'BETA': 'VIPRS_BETA'}, inplace=True)\n", 287 | "\n", 288 | "# Merge the two tables:\n", 289 | "merged_table = sumstats.merge(effect_table)\n", 290 | "\n", 291 | "# Plot the results:\n", 292 | "plt.scatter(merged_table['GWAS_BETA'], \n", 293 | " merged_table['VIPRS_BETA'], \n", 294 | " alpha=.5,\n", 295 | " marker='.')\n", 296 | "plt.xlabel(\"Marginal BETA (GWAS)\")\n", 297 | "plt.ylabel(\"VIPRS Posterior Mean for BETA\")\n", 298 | "\n", 299 | "# Plot the unity line to highlight differences in magnitude:\n", 300 | "x = np.linspace(merged_table[['GWAS_BETA', 'VIPRS_BETA']].min().min(), \n", 301 | " merged_table[['GWAS_BETA', 'VIPRS_BETA']].max().max(), 100)\n", 302 | "plt.plot(x, x, c='red', ls='--')" 303 | ], 304 | "outputs": [] 305 | }, 306 | { 307 | "cell_type": "markdown", 308 | "metadata": {}, 309 | "source": [ 310 | "This plot is a nice illustration of the **selective shrinkage** effect that results \n", 311 | "from using sparse Bayesian priors, like the **Spike-and-Slab prior** employed by `VIPRS`. Here, the effects \n", 312 | "for most variants are shrunk towards zero, whereas the few variants that are strongly associated \n", 313 | "with the phenotype retain their effects." 
314 | ] 315 | }, 316 | { 317 | "cell_type": "markdown", 318 | "metadata": { 319 | "ExecuteTime": { 320 | "end_time": "2024-04-05T17:05:12.134352Z", 321 | "start_time": "2024-04-05T17:05:12.075362Z" 322 | } 323 | }, 324 | "source": [ 325 | "## 4) Prediction / Generating polygenic scores" 326 | ] 327 | }, 328 | { 329 | "cell_type": "markdown", 330 | "metadata": { 331 | "pycharm": { 332 | "name": "#%% md\n" 333 | } 334 | }, 335 | "source": [ 336 | "Once convergence is achieved, we are going to predict (i.e. compute polygenic scores) on the European samples in the 1000G data." 337 | ] 338 | }, 339 | { 340 | "cell_type": "code", 341 | "execution_count": 13, 342 | "metadata": { 343 | "ExecuteTime": { 344 | "end_time": "2024-04-05T17:12:04.390626Z", 345 | "start_time": "2024-04-05T17:12:04.069105Z" 346 | }, 347 | "pycharm": { 348 | "name": "#%%\n" 349 | } 350 | }, 351 | "source": [ 352 | "# Obtain height PGS estimates for the European samples in 1000G Project:\n", 353 | "height_pgs = v.predict()\n", 354 | "\n", 355 | "# plot distribution of height PGS:\n", 356 | "\n", 357 | "plt.hist(height_pgs)\n", 358 | "plt.xlabel(\"Height PGS\")\n", 359 | "plt.title(\"Height PGS in 1000G (EUR)\")" 360 | ], 361 | "outputs": [] 362 | } 363 | ], 364 | "metadata": { 365 | "kernelspec": { 366 | "display_name": "Python 3 (ipykernel)", 367 | "language": "python", 368 | "name": "python3" 369 | }, 370 | "language_info": { 371 | "codemirror_mode": { 372 | "name": "ipython", 373 | "version": 3 374 | }, 375 | "file_extension": ".py", 376 | "mimetype": "text/x-python", 377 | "name": "python", 378 | "nbconvert_exporter": "python", 379 | "pygments_lexer": "ipython3", 380 | "version": "3.11.5" 381 | } 382 | }, 383 | "nbformat": 4, 384 | "nbformat_minor": 1 385 | } 386 | --------------------------------------------------------------------------------