├── docs ├── _static │ └── .gitignore ├── contributing.rst ├── readme.rst ├── authors.rst ├── changelog.rst ├── license.rst ├── requirements.txt ├── Makefile ├── index.rst └── conf.py ├── mypy.ini ├── requirements.txt ├── example ├── result │ ├── test.Wm.tsv.gz │ ├── test.Zm.tsv.gz │ ├── test.Wvar.tsv.gz │ ├── test.Zvar.tsv.gz │ ├── demo_old.Wm.tsv.gz │ ├── demo_old.Zm.tsv.gz │ ├── demo_test.Wm.tsv.gz │ ├── demo_test.Zm.tsv.gz │ ├── test.factor.tsv.gz │ ├── demo_old.Wvar.tsv.gz │ ├── demo_old.Zvar.tsv.gz │ ├── demo_test.Wvar.tsv.gz │ ├── demo_test.Zvar.tsv.gz │ ├── demo_old.factor.tsv.gz │ └── demo_test.factor.tsv.gz ├── data │ ├── n20_p1k.Zscore.tsv.gz │ └── n20_p1k.SampleN.tsv ├── run_demo.sh ├── test_skeleton.py └── runFactorGo.old.py ├── requirements_dev.txt ├── AUTHORS.rst ├── CHANGELOG.rst ├── .isort.cfg ├── tests ├── conftest.py ├── utils.py └── test_factorgo.py ├── pyproject.toml ├── .github └── workflows │ └── package_install.yml ├── .readthedocs.yml ├── src └── factorgo │ ├── __init__.py │ ├── util.py │ ├── io.py │ ├── cli.py │ └── infer.py ├── .coveragerc ├── setup.py ├── .gitignore ├── LICENSE.txt ├── .pre-commit-config.yaml ├── tox.ini ├── setup.cfg ├── README.md └── CONTRIBUTING.rst /docs/_static/.gitignore: -------------------------------------------------------------------------------- 1 | # Empty directory 2 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | allow_redefinition = true 3 | -------------------------------------------------------------------------------- /docs/contributing.rst: -------------------------------------------------------------------------------- 1 | .. 
include:: ../CONTRIBUTING.rst 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jax 2 | jaxlib 3 | numpy 4 | pandas 5 | -------------------------------------------------------------------------------- /docs/readme.rst: -------------------------------------------------------------------------------- 1 | .. _readme: 2 | .. include:: ../README.rst 3 | -------------------------------------------------------------------------------- /docs/authors.rst: -------------------------------------------------------------------------------- 1 | .. _authors: 2 | .. include:: ../AUTHORS.rst 3 | -------------------------------------------------------------------------------- /docs/changelog.rst: -------------------------------------------------------------------------------- 1 | .. _changes: 2 | .. include:: ../CHANGELOG.rst 3 | -------------------------------------------------------------------------------- /docs/license.rst: -------------------------------------------------------------------------------- 1 | .. _license: 2 | 3 | ======= 4 | License 5 | ======= 6 | 7 | .. 
include:: ../LICENSE.txt 8 | -------------------------------------------------------------------------------- /example/result/test.Wm.tsv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mancusolab/FactorGo/HEAD/example/result/test.Wm.tsv.gz -------------------------------------------------------------------------------- /example/result/test.Zm.tsv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mancusolab/FactorGo/HEAD/example/result/test.Zm.tsv.gz -------------------------------------------------------------------------------- /example/result/test.Wvar.tsv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mancusolab/FactorGo/HEAD/example/result/test.Wvar.tsv.gz -------------------------------------------------------------------------------- /example/result/test.Zvar.tsv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mancusolab/FactorGo/HEAD/example/result/test.Zvar.tsv.gz -------------------------------------------------------------------------------- /example/data/n20_p1k.Zscore.tsv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mancusolab/FactorGo/HEAD/example/data/n20_p1k.Zscore.tsv.gz -------------------------------------------------------------------------------- /example/result/demo_old.Wm.tsv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mancusolab/FactorGo/HEAD/example/result/demo_old.Wm.tsv.gz -------------------------------------------------------------------------------- /example/result/demo_old.Zm.tsv.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/mancusolab/FactorGo/HEAD/example/result/demo_old.Zm.tsv.gz -------------------------------------------------------------------------------- /example/result/demo_test.Wm.tsv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mancusolab/FactorGo/HEAD/example/result/demo_test.Wm.tsv.gz -------------------------------------------------------------------------------- /example/result/demo_test.Zm.tsv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mancusolab/FactorGo/HEAD/example/result/demo_test.Zm.tsv.gz -------------------------------------------------------------------------------- /example/result/test.factor.tsv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mancusolab/FactorGo/HEAD/example/result/test.factor.tsv.gz -------------------------------------------------------------------------------- /example/result/demo_old.Wvar.tsv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mancusolab/FactorGo/HEAD/example/result/demo_old.Wvar.tsv.gz -------------------------------------------------------------------------------- /example/result/demo_old.Zvar.tsv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mancusolab/FactorGo/HEAD/example/result/demo_old.Zvar.tsv.gz -------------------------------------------------------------------------------- /example/result/demo_test.Wvar.tsv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mancusolab/FactorGo/HEAD/example/result/demo_test.Wvar.tsv.gz -------------------------------------------------------------------------------- /example/result/demo_test.Zvar.tsv.gz: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mancusolab/FactorGo/HEAD/example/result/demo_test.Zvar.tsv.gz -------------------------------------------------------------------------------- /example/result/demo_old.factor.tsv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mancusolab/FactorGo/HEAD/example/result/demo_old.factor.tsv.gz -------------------------------------------------------------------------------- /example/result/demo_test.factor.tsv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mancusolab/FactorGo/HEAD/example/result/demo_test.factor.tsv.gz -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | pre-commit 2 | pymdown-extensions 3 | pytest-accept 4 | pytest-cov 5 | setuptools 6 | sphinx>=5 7 | sphinx_immaterial 8 | -------------------------------------------------------------------------------- /AUTHORS.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | Contributors 3 | ============ 4 | 5 | * Eleanor Zhang 6 | * Steven Gazal 7 | * Nicholas Mancuso 8 | -------------------------------------------------------------------------------- /CHANGELOG.rst: -------------------------------------------------------------------------------- 1 | ========= 2 | Changelog 3 | ========= 4 | 5 | Version 0.1 6 | =========== 7 | 8 | - Feature A added 9 | - FIX: nasty bug #1729 fixed 10 | - add your changes here! 
11 | -------------------------------------------------------------------------------- /example/run_demo.sh: -------------------------------------------------------------------------------- 1 | python runFactorGo.old.py \ 2 | ./data/n20_p1k.Zscore.tsv.gz \ 3 | ./data/n20_p1k.SampleN.tsv \ 4 | -k 5 \ 5 | --scaledat \ 6 | -o ./result/demo_old\ 7 | -p cpu 8 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | profile = black 3 | known_first_party = factorgo 4 | known_jax = jax 5 | sections = FUTURE, STDLIB, THIRDPARTY, JAX, FIRSTPARTY, LOCALFOLDER 6 | combine_as_imports = true 7 | multi_line_output = 3 8 | skip=docs 9 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | # Requirements file for ReadTheDocs, check .readthedocs.yml. 2 | # To build the module reference correctly, make sure every external package 3 | # under `install_requires` in `setup.cfg` is also listed here! 4 | sphinx>=3.2.1 5 | # sphinx_rtd_theme 6 | -------------------------------------------------------------------------------- /example/data/n20_p1k.SampleN.tsv: -------------------------------------------------------------------------------- 1 | N 2 | 58890 3 | 420473 4 | 418817 5 | 420531 6 | 35764 7 | 213747 8 | 420473 9 | 58890 10 | 420531 11 | 84306 12 | 411183 13 | 411741 14 | 420473 15 | 420531 16 | 420531 17 | 58890 18 | 385675 19 | 405376 20 | 420473 21 | 85877 22 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dummy conftest.py for vibrate. 3 | 4 | If you don't know what this is for, just leave it empty. 
def assert_array_eq(target_path, queue_path, rtol=1e-5):
    """Assert that two headerless TSV files hold numerically close arrays.

    :param target_path: path to the expected (reference) TSV file
    :param queue_path: path to the produced (query) TSV file
    :param rtol: relative tolerance forwarded to numpy.testing.assert_allclose
    :return: True when the arrays match; raises AssertionError otherwise.
    """
    target = pd.read_csv(target_path, delimiter="\t", header=None).to_numpy()
    queue = pd.read_csv(queue_path, delimiter="\t", header=None).to_numpy()

    nptest.assert_allclose(target, queue, rtol=rtol)
    # Return True so the helper also composes with a bare `assert` at call
    # sites (tests/test_factorgo.py asserts on this return value, which was
    # previously an implicit None and made that assertion always fail).
    return True
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details

# Required
version: 2

# The top-level `python.version` key was removed from the v2 schema; the
# interpreter is now selected via build.tools.
build:
  os: ubuntu-22.04
  tools:
    python: "3.8"

# Build documentation in the docs/ directory with Sphinx
sphinx:
  configuration: docs/conf.py

# Optionally build your docs in additional formats such as PDF
formats:
  - pdf

python:
  install:
    - requirements: docs/requirements.txt
    - {path: ., method: pip}
import pytest

from factorgo.skeleton import fib, main

__author__ = "Nicholas Mancuso"
__copyright__ = "Nicholas Mancuso"
__license__ = "MIT"


def test_fib():
    """API tests: spot-check fib values and the negative-input guard."""
    # NOTE(review): factorgo.skeleton is PyScaffold scaffolding and does not
    # appear under src/factorgo in this tree — confirm before running.
    for arg, expected in ((1, 1), (2, 1), (7, 13)):
        assert fib(arg) == expected
    with pytest.raises(AssertionError):
        fib(-10)


def test_main(capsys):
    """CLI tests: main() prints the expected Fibonacci message to stdout."""
    # capsys is the pytest fixture that captures stdout/stderr:
    # https://docs.pytest.org/en/stable/capture.html
    main(["7"])
    out = capsys.readouterr().out
    assert "The 7-th Fibonacci number is 13" in out
def test_cli_tool_writes_file():
    """End-to-end CLI smoke test.

    Runs the `factorgo` console script on the demo data and compares the
    produced factor file against the checked-in reference result. All paths
    are relative to the repository root (pytest's rootdir).
    """
    # BUG FIX: the reference/output paths previously used "../example/..."
    # while the subprocess wrote to "./example/...", so the existence check
    # and comparison looked in a different directory than the one written to.
    target_file = "./example/result/demo_old.factor.tsv.gz"
    output_file = "./example/result/demo_test.factor.tsv.gz"

    subprocess.run(
        [
            "factorgo",
            "./example/data/n20_p1k.Zscore.tsv.gz",
            "./example/data/n20_p1k.SampleN.tsv",
            "-k",
            "5",
            "--scale",
            "-o",
            "./example/result/demo_test",
        ],
        capture_output=False,
        text=False,
        check=True,  # fail loudly here if the CLI itself errors out
    )

    assert os.path.isfile(output_file)
    # assert_array_eq raises AssertionError itself on mismatch; calling it
    # directly avoids the old `assert assert_array_eq(...)` pattern, which
    # asserted on an implicit None return and therefore always failed
    # (the source of the original "not sure why raise error" TODO).
    assert_array_eq(target_file, output_file)
20 | """ 21 | if platform is None: 22 | platform = os.getenv("JAX_PLATFORM_NAME", "cpu") 23 | jax.config.update("jax_platform_name", platform) 24 | return 25 | 26 | 27 | def set_hyper(hyper: Tuple) -> None: 28 | """ 29 | set user-specified hyper parameters in priors 30 | """ 31 | infer.HyperParams.halpha_a = float(hyper[0]) 32 | infer.HyperParams.halpha_b = float(hyper[1]) 33 | infer.HyperParams.htau_a = float(hyper[2]) 34 | infer.HyperParams.htau_b = float(hyper[3]) 35 | infer.HyperParams.hbeta = float(hyper[4]) 36 | return 37 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2022 Nicholas Mancuso 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | exclude: '^(docs/conf.py|tests/testdata/.*|example/.*)' 2 | 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.0.1 6 | hooks: 7 | - id: trailing-whitespace 8 | - id: check-added-large-files 9 | - id: check-ast 10 | - id: check-json 11 | - id: check-merge-conflict 12 | - id: check-xml 13 | - id: check-yaml 14 | - id: debug-statements 15 | - id: end-of-file-fixer 16 | - id: requirements-txt-fixer 17 | - id: mixed-line-ending 18 | args: ['--fix=auto'] # replace 'auto' with 'lf' to enforce Linux/Mac line endings or 'crlf' for Windows 19 | 20 | - repo: https://github.com/pycqa/isort 21 | rev: 5.12.0 22 | hooks: 23 | - id: isort 24 | 25 | - repo: https://github.com/psf/black 26 | rev: 21.11b1 27 | hooks: 28 | - id: black 29 | language_version: python3 30 | additional_dependencies: ['click==8.0.4'] 31 | 32 | - repo: https://github.com/PyCQA/flake8 33 | rev: 4.0.1 34 | hooks: 35 | - id: flake8 36 | ## You can add flake8 plugins via `additional_dependencies`: 37 | # additional_dependencies: [flake8-bugbear] 38 | 39 | - repo: https://github.com/pre-commit/mirrors-mypy 40 | rev: 'v0.971' # Use the sha / tag you want to point at 41 | hooks: 42 | - id: mypy 43 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 
9 | BUILDDIR = _build 10 | AUTODOCDIR = api 11 | 12 | # User-friendly check for sphinx-build 13 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $?), 1) 14 | $(error "The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from https://sphinx-doc.org/") 15 | endif 16 | 17 | .PHONY: help clean Makefile 18 | 19 | # Put it first so that "make" without argument is like "make help". 20 | help: 21 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 22 | 23 | clean: 24 | rm -rf $(BUILDDIR)/* $(AUTODOCDIR) 25 | 26 | # Catch-all target: route all unknown targets to Sphinx using the new 27 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 28 | %: Makefile 29 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 30 | -------------------------------------------------------------------------------- /src/factorgo/io.py: -------------------------------------------------------------------------------- 1 | from logging import Logger 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from jax import numpy as jnp 7 | 8 | 9 | def read_data(z_path: str, N_path: str, log: Logger, scale: bool = True): 10 | """ 11 | input z score summary stats: 12 | headers are ["snp", "trait1", "trait2", ..., "traitn"] 13 | 14 | input sample size file: 15 | one column of sample size (with header) which has the same order as above 16 | """ 17 | 18 | # Read dataset (rows are SNPs) 19 | df_z = pd.read_csv(z_path, delimiter="\t", header=0) 20 | snp_col = df_z.columns[0] 21 | 22 | # drop the first column (axis = 1) and convert to nxp 23 | df_z.drop(labels=[snp_col], axis=1, inplace=True) 24 | df_z = df_z.astype("float").T 25 | 26 | if scale: 27 | df_z = df_z.subtract(df_z.mean()) 28 | 
def write_results(output, f_info, ordered_Z_m, ordered_W_m, W_var, Z_var, f_order):
    """Write FactorGo posterior estimates to gzipped TSV files.

    Files are written as ``{output}.{factor,Zm,Wm,Zvar,Wvar}.tsv.gz``.

    :param output: output path prefix
    :param f_info: per-factor summary table, written verbatim with ``%s``
    :param ordered_Z_m: (n, k) posterior mean of trait loadings, columns
        already sorted by factor order
    :param ordered_W_m: posterior mean of variant loadings, columns already sorted
    :param W_var: (k, k) posterior covariance of W; only its diagonal is kept
    :param Z_var: (n, k, k) per-trait posterior covariances; only the
        diagonals are kept
    :param f_order: permutation applied to factor columns of the variances
    """
    # Keep only the variances (diagonal) and reorder to the sorted factors.
    ordered_W_var = jnp.diagonal(W_var)[f_order]

    # Extract each trait's k-vector of variances in one vectorized call
    # instead of the previous n-iteration Python loop, which invoked
    # jnp.diagonal (and a device sync) once per trait.
    Z_var_diag = np.diagonal(np.asarray(Z_var), axis1=1, axis2=2)
    ordered_Z_var_diag = Z_var_diag[:, f_order]

    # .real drops any residual imaginary component — presumably from an
    # upstream eigen-decomposition; TODO confirm against the caller.
    np.savetxt(f"{output}.factor.tsv.gz", f_info, fmt="%s", delimiter="\t")
    np.savetxt(f"{output}.Zm.tsv.gz", ordered_Z_m.real, fmt="%s", delimiter="\t")
    np.savetxt(f"{output}.Wm.tsv.gz", ordered_W_m.real, fmt="%s", delimiter="\t")
    np.savetxt(
        f"{output}.Zvar.tsv.gz", ordered_Z_var_diag.real, fmt="%s", delimiter="\t"
    )
    np.savetxt(f"{output}.Wvar.tsv.gz", ordered_W_var.real, fmt="%s", delimiter="\t")
    return
14 | 15 | It is also possible to refer to the documentation of other Python packages 16 | with the `Python domain syntax`_. By default you can reference the 17 | documentation of `Sphinx`_, `Python`_, `NumPy`_, `SciPy`_, `matplotlib`_, 18 | `Pandas`_, `Scikit-Learn`_. You can add more by extending the 19 | ``intersphinx_mapping`` in your Sphinx's ``conf.py``. 20 | 21 | The pretty useful extension `autodoc`_ is activated by default and lets 22 | you include documentation from docstrings. Docstrings can be written in 23 | `Google style`_ (recommended!), `NumPy style`_ and `classical style`_. 24 | 25 | 26 | Contents 27 | ======== 28 | 29 | .. toctree:: 30 | :maxdepth: 2 31 | 32 | Overview 33 | Contributions & Help 34 | License 35 | Authors 36 | Changelog 37 | Module Reference 38 | 39 | 40 | Indices and tables 41 | ================== 42 | 43 | * :ref:`genindex` 44 | * :ref:`modindex` 45 | * :ref:`search` 46 | 47 | .. _toctree: https://www.sphinx-doc.org/en/master/usage/restructuredtext/directives.html 48 | .. _reStructuredText: https://www.sphinx-doc.org/en/master/usage/restructuredtext/basics.html 49 | .. _references: https://www.sphinx-doc.org/en/stable/markup/inline.html 50 | .. _Python domain syntax: https://www.sphinx-doc.org/en/master/usage/restructuredtext/domains.html#the-python-domain 51 | .. _Sphinx: https://www.sphinx-doc.org/ 52 | .. _Python: https://docs.python.org/ 53 | .. _Numpy: https://numpy.org/doc/stable 54 | .. _SciPy: https://docs.scipy.org/doc/scipy/reference/ 55 | .. _matplotlib: https://matplotlib.org/contents.html# 56 | .. _Pandas: https://pandas.pydata.org/pandas-docs/stable 57 | .. _Scikit-Learn: https://scikit-learn.org/stable 58 | .. _autodoc: https://www.sphinx-doc.org/en/master/ext/autodoc.html 59 | .. _Google style: https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings 60 | .. _NumPy style: https://numpydoc.readthedocs.io/en/latest/format.html 61 | .. 
_classical style: https://www.sphinx-doc.org/en/master/domains.html#info-field-lists 62 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | # Tox configuration file 2 | # Read more under https://tox.wiki/ 3 | # THIS SCRIPT IS SUPPOSED TO BE AN EXAMPLE. MODIFY IT ACCORDING TO YOUR NEEDS! 4 | 5 | [tox] 6 | minversion = 3.24 7 | envlist = default 8 | isolated_build = True 9 | 10 | 11 | [testenv] 12 | description = Invoke pytest to run automated tests 13 | setenv = 14 | TOXINIDIR = {toxinidir} 15 | passenv = 16 | HOME 17 | SETUPTOOLS_* 18 | extras = 19 | testing 20 | deps = 21 | -r {toxinidir}/requirements.txt 22 | -r {toxinidir}/requirements_dev.txt 23 | commands = 24 | pytest {posargs} 25 | 26 | 27 | # # To run `tox -e lint` you need to make sure you have a 28 | # # `.pre-commit-config.yaml` file. See https://pre-commit.com 29 | # [testenv:lint] 30 | # description = Perform static analysis and style checks 31 | # skip_install = True 32 | # deps = pre-commit 33 | # passenv = 34 | # HOMEPATH 35 | # PROGRAMDATA 36 | # SETUPTOOLS_* 37 | # commands = 38 | # pre-commit run --all-files {posargs:--show-diff-on-failure} 39 | 40 | 41 | [testenv:{build,clean}] 42 | description = 43 | build: Build the package in isolation according to PEP517, see https://github.com/pypa/build 44 | clean: Remove old distribution files and temporary build artifacts (./build and ./dist) 45 | # https://setuptools.pypa.io/en/stable/build_meta.html#how-to-use-it 46 | skip_install = True 47 | changedir = {toxinidir} 48 | deps = 49 | build: build[virtualenv] 50 | passenv = 51 | SETUPTOOLS_* 52 | commands = 53 | clean: python -c 'import shutil; [shutil.rmtree(p, True) for p in ("build", "dist", "docs/_build")]' 54 | clean: python -c 'import pathlib, shutil; [shutil.rmtree(p, True) for p in pathlib.Path("src").glob("*.egg-info")]' 55 | build: python -m build {posargs} 56 | # By 
default, both `sdist` and `wheel` are built. If your sdist is too big or you don't want 57 | # to make it available, consider running: `tox -e build -- --wheel` 58 | 59 | 60 | [testenv:{docs,doctests,linkcheck}] 61 | description = 62 | docs: Invoke sphinx-build to build the docs 63 | doctests: Invoke sphinx-build to run doctests 64 | linkcheck: Check for broken links in the documentation 65 | passenv = 66 | SETUPTOOLS_* 67 | setenv = 68 | DOCSDIR = {toxinidir}/docs 69 | BUILDDIR = {toxinidir}/docs/_build 70 | docs: BUILD = html 71 | doctests: BUILD = doctest 72 | linkcheck: BUILD = linkcheck 73 | deps = 74 | -r {toxinidir}/requirements.txt 75 | -r {toxinidir}/requirements_dev.txt 76 | commands = 77 | sphinx-build --color -b {env:BUILD} -d "{env:BUILDDIR}/doctrees" "{env:DOCSDIR}" "{env:BUILDDIR}/{env:BUILD}" {posargs} 78 | 79 | 80 | [testenv:publish] 81 | description = 82 | Publish the package you have been developing to a package index server. 83 | By default, it uses testpypi. If you really want to publish your package 84 | to be publicly accessible in PyPI, use the `-- --repository pypi` option. 85 | skip_install = True 86 | changedir = {toxinidir} 87 | passenv = 88 | # See: https://twine.readthedocs.io/en/latest/ 89 | TWINE_USERNAME 90 | TWINE_PASSWORD 91 | TWINE_REPOSITORY 92 | TWINE_REPOSITORY_URL 93 | deps = twine 94 | commands = 95 | python -m twine check dist/* 96 | python -m twine upload {posargs:--repository {env:TWINE_REPOSITORY:testpypi}} dist/* 97 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # This file is used to configure your project. 2 | # Read more about the various options under: 3 | # https://setuptools.pypa.io/en/latest/userguide/declarative_config.html 4 | # https://setuptools.pypa.io/en/latest/references/keywords.html 5 | 6 | [metadata] 7 | name = FactorGo 8 | description = Add a short description here! 
author = Zixuan Zhang
author_email = zzhang39@usc.edu
license = MIT
license_files = LICENSE.txt
long_description = file: README.md
long_description_content_type = text/markdown; charset=UTF-8
url = https://github.com/mancusolab/FactorGo
# Add here related links, for example:
project_urls =
    Documentation = https://pyscaffold.org/
#    Source = https://github.com/pyscaffold/pyscaffold/
#    Changelog = https://pyscaffold.org/en/latest/changelog.html
#    Tracker = https://github.com/pyscaffold/pyscaffold/issues
#    Conda-Forge = https://anaconda.org/conda-forge/pyscaffold
#    Download = https://pypi.org/project/PyScaffold/#files
#    Twitter = https://twitter.com/PyScaffold

# Change if running only on Windows, Mac or Linux (comma-separated)
platforms = any

# Add here all kinds of additional classifiers as defined under
# https://pypi.org/classifiers/
classifiers =
    Development Status :: 4 - Beta
    Programming Language :: Python


[options]
zip_safe = False
packages = find_namespace:
include_package_data = True
package_dir =
    =src

# Require a min/specific Python version (comma-separated conditions)
# python_requires = >=3.8

# Add here dependencies of your project (line-separated), e.g. requests>=2.2,<3.0.
# Version specifiers like >=2.2,<3.0 avoid problems due to API changes in
# new major versions. This works if the required packages follow Semantic Versioning.
# For more information, check out https://semver.org/.
50 | install_requires = 51 | importlib-metadata; python_version<"3.11" 52 | jax 53 | jaxlib 54 | pandas 55 | numpy 56 | 57 | 58 | [options.packages.find] 59 | where = src 60 | exclude = 61 | tests 62 | 63 | [options.extras_require] 64 | # Add here additional requirements for extra features, to install with: 65 | # `pip install factorgo[PDF]` like: 66 | # PDF = ReportLab; RXP 67 | 68 | # Add here test requirements (semicolon/line-separated) 69 | testing = 70 | setuptools 71 | pytest 72 | pytest-cov 73 | 74 | [options.entry_points] 75 | # Add here console scripts like: 76 | console_scripts = 77 | factorgo = factorgo.cli:run_cli 78 | 79 | [tool:pytest] 80 | # Specify command line options as you would do when invoking pytest directly. 81 | # e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml 82 | # in order to write a coverage file that can be read by Jenkins. 83 | # CAUTION: --cov flags may prohibit setting breakpoints while debugging. 84 | # Comment those flags to avoid this pytest issue. 85 | addopts = 86 | --cov factorgo --cov-report term-missing 87 | --verbose 88 | norecursedirs = 89 | dist 90 | build 91 | .tox 92 | testpaths = tests 93 | # Use pytest markers to select/deselect specific tests 94 | # markers = 95 | # slow: mark tests as slow (deselect with '-m "not slow"') 96 | # system: mark end-to-end system tests 97 | 98 | [devpi:upload] 99 | # Options for the devpi: PyPI server and packaging tool 100 | # VCS export must be deactivated since we are using setuptools-scm 101 | no_vcs = 1 102 | formats = bdist_wheel 103 | 104 | [flake8] 105 | # Some sane defaults for the code style checker flake8 106 | max_line_length = 120 107 | extend_ignore = E203, W503 108 | # ^ Black-compatible 109 | # E203 and W503 have edge cases handled by black 110 | exclude = 111 | .tox 112 | build 113 | dist 114 | .eggs 115 | docs/conf.py 116 | 117 | [pyscaffold] 118 | # PyScaffold's parameters when the project was created. 
import argparse as ap
import logging
import sys

import jax.random as rdm

from factorgo import infer, io, util

# FactorGo release reported in the startup banner.
VERSION = "1.0.0"


def get_logger(name, path=None):
    """Return the logger used to report FactorGo progress.

    :param name: logger name, typically ``__name__``.
    :param path: optional output prefix; when given, records are also written
        to ``<path>.log`` (truncated at the start of each run).
    :return: a configured :class:`logging.Logger`. Handlers are attached only
        on the first call for a given name, so repeated calls are safe.
    """
    logger = logging.getLogger(name)
    if not logger.handlers:
        # Prevent records from propagating to (and being duplicated by)
        # the root logger.
        logger.propagate = False

        log_format = "[%(asctime)s - %(levelname)s] %(message)s"
        date_format = "%Y-%m-%d %H:%M:%S"
        formatter = logging.Formatter(fmt=log_format, datefmt=date_format)

        console = logging.StreamHandler()
        console.setFormatter(formatter)
        logger.addHandler(console)

        if path is not None:
            # Use FileHandler so the log file is owned and closed by the
            # logging framework; the previous manual open() leaked the handle.
            disk_handler = logging.FileHandler("{}.log".format(path), mode="w")
            disk_handler.setFormatter(formatter)
            logger.addHandler(disk_handler)

    return logger


def _build_parser():
    """Construct the argument parser for the ``factorgo`` command-line tool."""
    argp = ap.ArgumentParser(description="")
    argp.add_argument(
        "Zscore_path",
        help="z score file must have this format: [snp, trait1, trait2, ..., traitn], "
        "with headers (any names)",
    )
    argp.add_argument(
        "N_path",
        help="Sample N file must be one column with same order as Z score file, "
        "with header (any name)",
    )
    argp.add_argument(
        "-k",
        type=int,
        help="Number of latent factors to estimate, "
        "maximum number should be <= min(n, p), "
        "where n is number of traits, p is number of variants",
    )
    argp.add_argument(
        "--elbo-tol",
        default=1e-3,
        type=float,
        help="Tolerance for change in ELBO to halt inference, default=1e-3",
    )
    argp.add_argument(
        "--hyper",
        default=None,
        nargs="+",
        type=float,
        # NB: values must be space-separated (argparse nargs='+'); the old help
        # text showed a comma-separated example, which float() cannot parse.
        help="Input hyperparameter in this order: alpha_a, alpha_b, tau_a, tau_b and mu_beta. "
        "Example: --hyper 1e-3 1e-3 1e-5 1e-5 1e-5",
    )
    argp.add_argument(
        "--max-iter",
        default=10000,
        type=int,
        help="Maximum number of iterations to learn parameters, default=10000",
    )
    argp.add_argument(
        "--init-factor",
        choices=["random", "svd"],
        default="random",
        help="How to initialize the latent factors and weights",
    )
    argp.add_argument(
        "--scale",
        action="store_true",
        default=False,
        help="Scale each SNPs effect across traits (Default=False)",
    )
    argp.add_argument(
        "--rate",
        default=50,
        type=int,
        help="Rate of printing elbo info; default is printing per 50 iters",
    )
    argp.add_argument(
        "-p",
        "--platform",
        choices=["cpu", "gpu", "tpu"],
        default="cpu",
        help="Change platform depending on hardware resource, default=cpu",
    )
    argp.add_argument(
        "-s",
        "--seed",
        type=int,
        default=123456789,
        help="Seed for randomization, default=123456789",
    )
    argp.add_argument(
        "-d",
        "--debug",
        action="store_true",
        default=False,
        help="Set logger to debug mode, default=False",
    )
    argp.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        default=False,
        help="Set logger to debug mode (same effect as --debug), default=False",
    )
    argp.add_argument(
        "-o",
        "--output",
        type=str,
        default="factorgo",
        help="Prefix path for output, default=factorgo",
    )
    return argp


def _main(args):
    """CLI entry point: parse arguments, load data, fit the model, write results.

    :param args: raw command-line arguments (``sys.argv[1:]``).
    :return: 0 on success.
    """
    args = _build_parser().parse_args(args)

    log = get_logger(__name__, args.output)
    # Honor either verbosity flag; previously --debug was parsed but ignored.
    if args.verbose or args.debug:
        log.setLevel(logging.DEBUG)
    else:
        log.setLevel(logging.INFO)

    # setup to use either CPU (default), GPU, or TPU
    util.set_platform(args.platform)

    # ensure 64bit precision (JAX defaults to 32bit)
    util.update_x64(True)

    # init key (for jax), split into 2 chunks
    key = rdm.PRNGKey(args.seed)
    key, key_init = rdm.split(key, 2)

    log.info(
        f"""
    #############################################

        Welcome to use FactorGo!
        Version: {VERSION}

    #############################################
    """
    )

    log.info("Loading GWAS summary statistics and sample size.")
    B, sampleN, sampleN_sqrt = io.read_data(
        args.Zscore_path, args.N_path, log, args.scale
    )
    log.info("Finished loading GWAS summary statistics and sample size.")

    n_studies, p_snps = B.shape
    log.info(f"Found N = {n_studies} studies, P = {p_snps} SNPs")

    # number of factors
    k = args.k
    log.info(f"User set K = {k} latent factors.")

    # set options for stopping rule
    options = infer.Options(args.elbo_tol, args.max_iter)

    # set 5 hyper-parameters: otherwise use default 1e-5
    if args.hyper is not None:
        util.set_hyper(args.hyper)

    log.info(
        f"""set hyper parameters
        halpha_a: {infer.HyperParams.halpha_a},
        halpha_b: {infer.HyperParams.halpha_b},
        htau_a: {infer.HyperParams.htau_a},
        htau_b: {infer.HyperParams.htau_b},
        hbeta: {infer.HyperParams.hbeta}
        """
    )

    W_m, W_var, Z_m, Z_var, f_info, f_order, ordered_W_m, ordered_Z_m = infer.fit(
        B, args, k, key_init, log, options, sampleN, sampleN_sqrt
    )

    log.info("Writing results.")
    io.write_results(
        args.output, f_info, ordered_Z_m, ordered_W_m, W_var, Z_var, f_order
    )
    log.info("Finished. Goodbye.")

    return 0


def run_cli():
    """Console-script entry point (see ``console_scripts`` in setup.cfg)."""
    return _main(sys.argv[1:])


if __name__ == "__main__":
    sys.exit(_main(sys.argv[1:]))
(https://www.cell.com/ajhg/abstract/S0002-9297(23)00353-1)

We are currently working on more detailed documentation. Feel free to contact me (zzhang39@usc.edu) if you need help running our tool or with further analysis. I am happy to schedule a Zoom call if needed.

[**Installation**](#installation)
| [**Example**](#example)
| [**Notes**](#notes)
| [**Support**](#support)
| [**Other Software**](#other-software)

## FactorGo model

FactorGo assumes the true genetic effect can be decomposed into latent pleiotropic factors.
Briefly, we model test statistics at $p$ independent variants from the $i$-th GWAS $Z_i \\approx \\sqrt{N}_i \\hat{\\beta}_i$ as a
linear combination of $k$ shared latent variant loadings $L \\in R^{p \\times k}$ with trait-specific factor scores $f_i \\in R^{k \\times 1}$ as

$$Z_i = \\sqrt{N}_i \\beta_i + \\epsilon_i = \\sqrt{N}_i (L f_i + \\mu) + \\epsilon_i $$

where $N_i$ is the sample size for the $i^{th}$ GWAS, $\\mu$ is the intercept and $\\epsilon_i \\sim N(0, \\tau^{-1}I_p)$ reflects residual
heterogeneity in statistical power across studies with precision scalar $\\tau$.
Given $Z = \\{Z_i\\}^n_{i=1}$, and model parameters $L$, $F$, $\\mu$, $\\tau$, we can compute the likelihood as

$$\\mathcal{L}(L, F, \\mu, \\tau | Z) = \\prod_i \\mathcal{N}_p ( \\sqrt{N_i} (L f_i + \\mu), \\tau^{-1} I_p)$$

To model our uncertainty in $L$, $F$, $\\mu$, we take a full Bayesian approach in the lower dimension latent space
similar to a Bayesian PCA model [1] as,

$$\Pr(F) = \\prod_{i=1}^{n} \\mathcal{N}_k (f_i | 0, I_k)$$

$$\Pr(L | \\alpha) = \\prod_{j=1}^{p} \\mathcal{N}_k (l^j | 0, diag(\\alpha^{-1}))$$

$$\Pr(\\mu) = \\mathcal{N}_p (\\mu | 0, \\phi^{-1} I_p)$$

where $\\alpha \\in R^{k \\times 1}_{>0}$ ($\\phi > 0$) controls the prior precision for variant loadings (intercept).
To avoid overfitting, 52 | and “shut off” uninformative factors when $k$ is misspecified, we use automatic relevance determination (ARD) [1]_ 53 | and place a prior over $\\alpha$ as 54 | 55 | $$\Pr(\\alpha | \\alpha_a, \\alpha_b) = \\prod_{q=1}^{k} G(\\alpha_q | \\alpha_a, \\alpha_b)$$ 56 | 57 | $$\Pr(\\tau | \\tau_a, \\tau_b) = G(\\tau | \\tau_a, \\tau_b)$$ 58 | 59 | Lastly, we place a prior over the shared residual variance across GWAS studies as $\\tau \\sim G(a , b)$. 60 | We impose broad priors by setting hyperparameters $\\phi = a_k = b_k= a_{\\tau} = b_{\\tau} = 10^{-5}$. 61 | 62 | ## Installation 63 | 64 | We recommend first create a conda environment and have `pip` installed. 65 | ```bash 66 | # download use http address 67 | git clone https://github.com/mancusolab/FactorGo.git 68 | # or use ssh agent 69 | git clone git@github.com:mancusolab/FactorGo.git 70 | 71 | cd factorgo 72 | pip install . 73 | ``` 74 | 75 | ## Example 76 | For illustration, we use example data stored in `/example/data`, 77 | including Z score summary statistics file and sample size file. 78 | 79 | To run ``factorgo`` command line tool, we specify the following input files and flags: 80 | 81 | * GWAS Zscore file: n20_p1k.Zscore.tsv.gz 82 | * Sample size file: n20_p1k.SampleN.tsv 83 | * -k 5: estimate 5 latent factors 84 | * --scale: the snp columns of Zscore matrix is center and standardized 85 | * -o: output directory and prefix 86 | 87 | For all available flags, please use ``factorgo -h``. 88 | 89 | ```bash 90 | factorgo \ 91 | ./example/data/n20_p1k.Zscore.tsv.gz \ 92 | ./example/data/n20_p1k.SampleN.tsv \ 93 | -k 5 \ 94 | --scale \ 95 | -o ./example/result/demo_test 96 | ``` 97 | 98 | The output contains five result files: 99 | 100 | 1. demo_test.Wm.tsv.gz: posterior mean of loading matrix W (pxk) 101 | 2. demo_test.Zm.tsv.gz: posterior mean of factor score Z (nxk) 102 | 3. demo_test.Wvar.tsv.gz: posterior variance of loading matrix W (kx1) 103 | 4. 
demo_test.Zvar.tsv.gz: posterior variance of factor score Z (nxk)
5. demo_test.factor.tsv.gz: contains the following three columns

| a) factor index (ordered by R2),
| b) posterior mean of ARD precision parameters,
| c) variance explained by each factor (R2)

## Notes

The default computation device for ``factorgo`` is CPU. To switch to a GPU device, you can specify the platform (cpu/gpu/tpu) using the flag `-p gpu`,
for example:

```bash
# the -p gpu flag switches computation to the GPU device
factorgo \
./example/data/n20_p1k.Zscore.tsv.gz \
./example/data/n20_p1k.SampleN.tsv \
-k 5 \
--scale \
-p gpu \
-o ./example/result/demo_test
```

``factorgo`` uses [JAX](https://github.com/google/jax) with [Just In Time](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html) compilation to achieve high-speed computation.
However, there are some [issues](https://github.com/google/jax/issues/5501) for JAX with Mac M1 chip.
To solve this, users need to initiate conda using [miniforge](https://github.com/conda-forge/miniforge), and then install ``factorgo`` using ``pip`` in the desired environment.

## References

[1] Bishop, C.M. (1999). Variational principal components. 509–514.

## Support

Please report any bugs or feature requests in the [Issue Tracker](https://github.com/mancusolab/FactorGo/issues).
If you have any questions or comments please contact zzhang39@usc.edu and/or nmancuso@usc.edu.


## Other Software

Feel free to use other software developed by [Mancuso
Lab](https://www.mancusolab.com/):

- [SuShiE](https://github.com/mancusolab/sushie): a Bayesian
  fine-mapping framework for molecular QTL data across multiple
  ancestries.
147 | - [MA-FOCUS](https://github.com/mancusolab/ma-focus): a Bayesian 148 | fine-mapping framework using 149 | [TWAS](https://www.nature.com/articles/ng.3506) statistics across 150 | multiple ancestries to identify the causal genes for complex traits. 151 | - [SuSiE-PCA](https://github.com/mancusolab/susiepca): a scalable 152 | Bayesian variable selection technique for sparse principal component 153 | analysis 154 | - [twas_sim](https://github.com/mancusolab/twas_sim): a Python 155 | software to simulate [TWAS](https://www.nature.com/articles/ng.3506) 156 | statistics. 157 | - [HAMSTA](https://github.com/tszfungc/hamsta): a Python software to 158 | estimate heritability explained by local ancestry data from 159 | admixture mapping summary statistics. 160 | 161 | ## Note 162 | 163 | This project has been set up using PyScaffold 4.1.1. For details and usage 164 | information on PyScaffold see https://pyscaffold.org/. 165 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # This file is execfile()d with the current directory set to its containing dir. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | # 7 | # All configuration values have a default; values that are commented out 8 | # serve to show the default. 9 | 10 | import os 11 | import sys 12 | import shutil 13 | 14 | # -- Path setup -------------------------------------------------------------- 15 | 16 | __location__ = os.path.dirname(__file__) 17 | 18 | # If extensions (or modules to document with autodoc) are in another directory, 19 | # add these directories to sys.path here. If the directory is relative to the 20 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
sys.path.insert(0, os.path.join(__location__, "../src"))

# -- Run sphinx-apidoc -------------------------------------------------------
# This hack is necessary since RTD does not issue `sphinx-apidoc` before running
# `sphinx-build -b html . _build/html`. See Issue:
# https://github.com/readthedocs/readthedocs.org/issues/1139
# DON'T FORGET: Check the box "Install your project inside a virtualenv using
# setup.py install" in the RTD Advanced Settings.
# Additionally it helps us to avoid running apidoc manually

try:  # for Sphinx >= 1.7
    from sphinx.ext import apidoc
except ImportError:
    from sphinx import apidoc

output_dir = os.path.join(__location__, "api")
module_dir = os.path.join(__location__, "../src/factorgo")
try:
    shutil.rmtree(output_dir)
except FileNotFoundError:
    pass

try:
    import sphinx

    cmd_line = f"sphinx-apidoc --implicit-namespaces -f -o {output_dir} {module_dir}"

    args = cmd_line.split(" ")
    # Compare version components numerically. The previous string-tuple
    # comparison was lexicographic, so e.g. ("1", "10") < ("1", "7") and
    # a hypothetical Sphinx 1.10 would have been misclassified.
    version_parts = []
    for part in sphinx.__version__.split(".")[:2]:
        digits = "".join(ch for ch in part if ch.isdigit())
        version_parts.append(int(digits) if digits else 0)
    if tuple(version_parts) >= (1, 7):
        # Sphinx >= 1.7 apidoc.main() no longer expects the program name
        # as argv[0], so drop it.
        args = args[1:]

    apidoc.main(args)
except Exception as e:
    print("Running `sphinx-apidoc` failed!\n{}".format(e))

# -- General configuration ---------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.
# needs_sphinx = '1.0'

# Add any Sphinx extension module names here, as strings. They can be extensions
# coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.intersphinx",
    "sphinx.ext.todo",
    "sphinx.ext.autosummary",
    "sphinx.ext.viewcode",
    "sphinx.ext.coverage",
    "sphinx.ext.doctest",
    "sphinx.ext.ifconfig",
    "sphinx.ext.mathjax",
    "sphinx.ext.napoleon",
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]

# The suffix of source filenames.
source_suffix = ".rst"

# The encoding of source files.
# source_encoding = 'utf-8-sig'

# The master toctree document.
master_doc = "index"

# General information about the project.
project = "factorgo"
copyright = "2022, Nicholas Mancuso"

# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# version: The short X.Y version.
# release: The full version, including alpha/beta/rc tags.
# If you don’t need the separation provided between version and release,
# just set them both to the same value.
try:
    # The installed package is named "factorgo" (see setup.cfg); the previous
    # "FactorGo" spelling always raised ImportError and silently fell back to
    # the READTHEDOCS_VERSION environment variable.
    from factorgo import __version__ as version
except ImportError:
    version = ""

if not version or version.lower() == "unknown":
    version = os.getenv("READTHEDOCS_VERSION", "unknown")  # automatically set by RTD

release = version

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
# language = None

# There are two options for replacing |today|: either, you set today to some
# non-false value, then it is used:
# today = ''
# Else, today_fmt is used as the format for a strftime call.
# today_fmt = '%B %d, %Y'

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", ".venv"]

# The reST default role (used for this markup: `text`) to use for all documents.
# default_role = None

# If true, '()' will be appended to :func: etc. cross-reference text.
# add_function_parentheses = True

# If true, the current module name will be prepended to all description
# unit titles (such as .. function::).
# add_module_names = True

# If true, sectionauthor and moduleauthor directives will be shown in the
# output. They are ignored by default.
# show_authors = False

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = "sphinx"

# A list of ignored prefixes for module index sorting.
# modindex_common_prefix = []

# If true, keep warnings as "system message" paragraphs in the built documents.
# keep_warnings = False

# If this is True, todo emits a warning for each TODO entries. The default is False.
todo_emit_warnings = True


# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
html_theme = "alabaster"

# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
# Widen the Alabaster sidebar and page body beyond the theme defaults.
html_theme_options = {
    "sidebar_width": "300px",
    "page_width": "1200px"
}

# Add any paths that contain custom themes here, relative to this directory.
# html_theme_path = []

# The name for this set of Sphinx documents. If None, it defaults to
# "<project> v<release> documentation".
# html_title = None

# A shorter title for the navigation bar. Default is the same as html_title.
# html_short_title = None

# The name of an image file (relative to this directory) to place at the top
# of the sidebar.
# html_logo = ""

# The name of an image file (within the static path) to use as favicon of the
# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
# pixels large.
# html_favicon = None

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]

# If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
# using the given strftime format.
# html_last_updated_fmt = '%b %d, %Y'

# If true, SmartyPants will be used to convert quotes and dashes to
# typographically correct entities.
# html_use_smartypants = True

# Custom sidebar templates, maps document names to template names.
# html_sidebars = {}

# Additional templates that should be rendered to pages, maps page names to
# template names.
# html_additional_pages = {}

# If false, no module index is generated.
# html_domain_indices = True

# If false, no index is generated.
# html_use_index = True

# If true, the index is split into individual pages for each letter.
# html_split_index = False

# If true, links to the reST sources are added to the pages.
# html_show_sourcelink = True

# If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
# html_show_sphinx = True

# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
# html_show_copyright = True

# If true, an OpenSearch description file will be output, and all pages will
# contain a tag referring to it. The value of this option must be the
# base URL from which the finished HTML is served.
# html_use_opensearch = ''

# This is the file name suffix for HTML files (e.g. ".xhtml").
# html_file_suffix = None

# Output file base name for HTML help builder.
htmlhelp_basename = "factorgo-doc"


# -- Options for LaTeX output ------------------------------------------------

latex_elements = {
    # The paper size ("letterpaper" or "a4paper").
    # "papersize": "letterpaper",
    # The font size ("10pt", "11pt" or "12pt").
    # "pointsize": "10pt",
    # Additional stuff for the LaTeX preamble.
    # "preamble": "",
}

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title, author, documentclass [howto/manual]).
latex_documents = [
    ("index", "user_guide.tex", "factorgo Documentation", "Nicholas Mancuso", "manual")
]

# The name of an image file (relative to this directory) to place at the top of
# the title page.
# latex_logo = ""

# For "manual" documents, if this is true, then toplevel headings are parts,
# not chapters.
# latex_use_parts = False

# If true, show page references after internal links.
# latex_show_pagerefs = False

# If true, show URL addresses after external links.
# latex_show_urls = False

# Documents to append as an appendix to all manuals.
# latex_appendices = []

# If false, no module index is generated.
270 | # latex_domain_indices = True 271 | 272 | # -- External mapping -------------------------------------------------------- 273 | python_version = ".".join(map(str, sys.version_info[0:2])) 274 | intersphinx_mapping = { 275 | "sphinx": ("https://www.sphinx-doc.org/en/master", None), 276 | "python": ("https://docs.python.org/" + python_version, None), 277 | "matplotlib": ("https://matplotlib.org", None), 278 | "numpy": ("https://numpy.org/doc/stable", None), 279 | "sklearn": ("https://scikit-learn.org/stable", None), 280 | "pandas": ("https://pandas.pydata.org/pandas-docs/stable", None), 281 | "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), 282 | "setuptools": ("https://setuptools.readthedocs.io/en/stable/", None), 283 | "pyscaffold": ("https://pyscaffold.org/en/stable", None), 284 | } 285 | 286 | print(f"loading configurations for {project} {version} ...", file=sys.stderr) 287 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | .. todo:: THIS IS SUPPOSED TO BE AN EXAMPLE. MODIFY IT ACCORDING TO YOUR NEEDS! 2 | 3 | The document assumes you are using a source repository service that promotes a 4 | contribution model similar to `GitHub's fork and pull request workflow`_. 5 | While this is true for the majority of services (like GitHub, GitLab, 6 | BitBucket), it might not be the case for private repositories (e.g., when 7 | using Gerrit). 8 | 9 | Also notice that the code examples might refer to GitHub URLs or the text 10 | might use GitHub specific terminology (e.g., *Pull Request* instead of *Merge 11 | Request*). 12 | 13 | Please make sure to check the document having these assumptions in mind 14 | and update things accordingly. 15 | 16 | .. todo:: Provide the correct links/replacements at the bottom of the document. 17 | 18 | .. 
todo:: You might want to have a look on `PyScaffold's contributor's guide`_, 19 | 20 | especially if your project is open source. The text should be very similar to 21 | this template, but there are a few extra contents that you might decide to 22 | also include, like mentioning labels of your issue tracker or automated 23 | releases. 24 | 25 | 26 | ============ 27 | Contributing 28 | ============ 29 | 30 | Welcome to ``FactorGo`` contributor's guide. 31 | 32 | This document focuses on getting any potential contributor familiarized 33 | with the development processes, but `other kinds of contributions`_ are also 34 | appreciated. 35 | 36 | If you are new to using git_ or have never collaborated in a project previously, 37 | please have a look at `contribution-guide.org`_. Other resources are also 38 | listed in the excellent `guide created by FreeCodeCamp`_ [#contrib1]_. 39 | 40 | Please notice, all users and contributors are expected to be **open, 41 | considerate, reasonable, and respectful**. When in doubt, `Python Software 42 | Foundation's Code of Conduct`_ is a good reference in terms of behavior 43 | guidelines. 44 | 45 | 46 | Issue Reports 47 | ============= 48 | 49 | If you experience bugs or general issues with ``FactorGo``, please have a look 50 | on the `issue tracker`_. If you don't see anything useful there, please feel 51 | free to fire an issue report. 52 | 53 | .. tip:: 54 | Please don't forget to include the closed issues in your search. 55 | Sometimes a solution was already reported, and the problem is considered 56 | **solved**. 57 | 58 | New issue reports should include information about your programming environment 59 | (e.g., operating system, Python version) and steps to reproduce the problem. 60 | Please try also to simplify the reproduction steps to a very minimal example 61 | that still illustrates the problem you are facing. By removing other factors, 62 | you help us to identify the root cause of the issue. 
63 | 
64 | 
65 | Documentation Improvements
66 | ==========================
67 | 
68 | You can help improve ``FactorGo`` docs by making them more readable and coherent, or
69 | by adding missing information and correcting mistakes.
70 | 
71 | ``FactorGo`` documentation uses Sphinx_ as its main documentation compiler.
72 | This means that the docs are kept in the same repository as the project code, and
73 | that any documentation update is done in the same way as a code contribution.
74 | 
75 | .. todo:: Don't forget to mention which markup language you are using.
76 | 
77 | e.g., reStructuredText_ or CommonMark_ with MyST_ extensions.
78 | 
79 | .. todo:: If your project is hosted on GitHub, you can also mention the following tip:
80 | 
81 | .. tip::
82 | Please notice that the `GitHub web interface`_ provides a quick way of
83 | proposing changes in ``FactorGo``'s files. While this mechanism can
84 | be tricky for normal code contributions, it works perfectly fine for
85 | contributing to the docs, and can be quite handy.
86 | 
87 | If you are interested in trying this method out, please navigate to
88 | the ``docs`` folder in the source repository_, find which file you
89 | would like to propose changes to and click on the little pencil icon at the
90 | top, to open `GitHub's code editor`_. Once you finish editing the file,
91 | please write a message in the form at the bottom of the page describing
92 | which changes you have made and what are the motivations behind them and
93 | submit your proposal.
94 | 
95 | When working on documentation changes in your local machine, you can
96 | compile them using |tox|_::
97 | 
98 | tox -e docs
99 | 
100 | and use Python's built-in web server for a preview in your web browser
101 | (``http://localhost:8000``)::
102 | 
103 | python3 -m http.server --directory 'docs/_build/html'
104 | 
105 | 
106 | Code Contributions
107 | ==================
108 | 
109 | .. todo:: Please include a reference or explanation about the internals of the project.
110 | 
111 | An architecture description, design principles or at least a summary of the
112 | main concepts will make it easy for potential contributors to get started
113 | quickly.
114 | 
115 | Submit an issue
116 | ---------------
117 | 
118 | Before you work on any non-trivial code contribution it's best to first create
119 | a report in the `issue tracker`_ to start a discussion on the subject.
120 | This often provides additional considerations and avoids unnecessary work.
121 | 
122 | Create an environment
123 | ---------------------
124 | 
125 | Before you start coding, we recommend creating an isolated `virtual
126 | environment`_ to avoid any problems with your installed Python packages.
127 | This can easily be done via either |virtualenv|_::
128 | 
129 | virtualenv <PATH TO VENV>
130 | source <PATH TO VENV>/bin/activate
131 | 
132 | or Miniconda_::
133 | 
134 | conda create -n FactorGo python=3 six virtualenv pytest pytest-cov
135 | conda activate FactorGo
136 | 
137 | Clone the repository
138 | --------------------
139 | 
140 | #. Create a user account on |the repository service| if you do not already have one.
141 | #. Fork the project repository_: click on the *Fork* button near the top of the
142 | page. This creates a copy of the code under your account on |the repository service|.
143 | #. Clone this copy to your local disk::
144 | 
145 | git clone git@github.com:YourLogin/FactorGo.git
146 | cd FactorGo
147 | 
148 | #. You should run::
149 | 
150 | pip install -U pip setuptools -e .
151 | 
152 | to be able to run ``putup --help``.
153 | 
154 | .. todo:: if you are not using pre-commit, please remove the following item:
155 | 
156 | #. Install |pre-commit|_::
157 | 
158 | pip install pre-commit
159 | pre-commit install
160 | 
161 | ``FactorGo`` comes with a lot of hooks configured to automatically help the
162 | developer to check the code being written.
163 | 
164 | Implement your changes
165 | ----------------------
166 | 
167 | #.
Create a branch to hold your changes:: 168 | 169 | git checkout -b my-feature 170 | 171 | and start making changes. Never work on the master branch! 172 | 173 | #. Start your work on this branch. Don't forget to add docstrings_ to new 174 | functions, modules and classes, especially if they are part of public APIs. 175 | 176 | #. Add yourself to the list of contributors in ``AUTHORS.rst``. 177 | 178 | #. When you’re done editing, do:: 179 | 180 | git add 181 | git commit 182 | 183 | to record your changes in git_. 184 | 185 | .. todo:: if you are not using pre-commit, please remove the following item: 186 | 187 | Please make sure to see the validation messages from |pre-commit|_ and fix 188 | any eventual issues. 189 | This should automatically use flake8_/black_ to check/fix the code style 190 | in a way that is compatible with the project. 191 | 192 | .. important:: Don't forget to add unit tests and documentation in case your 193 | contribution adds an additional feature and is not just a bugfix. 194 | 195 | Moreover, writing a `descriptive commit message`_ is highly recommended. 196 | In case of doubt, you can check the commit history with:: 197 | 198 | git log --graph --decorate --pretty=oneline --abbrev-commit --all 199 | 200 | to look for recurring communication patterns. 201 | 202 | #. Please check that your changes don't break any unit tests with:: 203 | 204 | tox 205 | 206 | (after having installed |tox|_ with ``pip install tox`` or ``pipx``). 207 | 208 | You can also use |tox|_ to run several other pre-configured tasks in the 209 | repository. Try ``tox -av`` to see a list of the available checks. 210 | 211 | Submit your contribution 212 | ------------------------ 213 | 214 | #. If everything works fine, push your local branch to |the repository service| with:: 215 | 216 | git push -u origin my-feature 217 | 218 | #. Go to the web page of your fork and click |contribute button| 219 | to send your changes for review. 220 | 221 | .. 
todo:: if you are using GitHub, you can uncomment the following paragraph 222 | 223 | Find more detailed information in `creating a PR`_. You might also want to open 224 | the PR as a draft first and mark it as ready for review after the feedbacks 225 | from the continuous integration (CI) system or any required fixes. 226 | 227 | 228 | Troubleshooting 229 | --------------- 230 | 231 | The following tips can be used when facing problems to build or test the 232 | package: 233 | 234 | #. Make sure to fetch all the tags from the upstream repository_. 235 | The command ``git describe --abbrev=0 --tags`` should return the version you 236 | are expecting. If you are trying to run CI scripts in a fork repository, 237 | make sure to push all the tags. 238 | You can also try to remove all the egg files or the complete egg folder, i.e., 239 | ``.eggs``, as well as the ``*.egg-info`` folders in the ``src`` folder or 240 | potentially in the root of your project. 241 | 242 | #. Sometimes |tox|_ misses out when new dependencies are added, especially to 243 | ``setup.cfg`` and ``docs/requirements.txt``. If you find any problems with 244 | missing dependencies when running a command with |tox|_, try to recreate the 245 | ``tox`` environment using the ``-r`` flag. For example, instead of:: 246 | 247 | tox -e docs 248 | 249 | Try running:: 250 | 251 | tox -r -e docs 252 | 253 | #. Make sure to have a reliable |tox|_ installation that uses the correct 254 | Python version (e.g., 3.7+). When in doubt you can run:: 255 | 256 | tox --version 257 | # OR 258 | which tox 259 | 260 | If you have trouble and are seeing weird errors upon running |tox|_, you can 261 | also try to create a dedicated `virtual environment`_ with a |tox|_ binary 262 | freshly installed. For example:: 263 | 264 | virtualenv .venv 265 | source .venv/bin/activate 266 | .venv/bin/pip install tox 267 | .venv/bin/tox -e all 268 | 269 | #. `Pytest can drop you`_ in an interactive session in the case an error occurs. 
270 | In order to do that you need to pass a ``--pdb`` option (for example by
271 | running ``tox -- -k <NAME OF THE FAILING TEST> --pdb``).
272 | You can also set up breakpoints manually instead of using the ``--pdb`` option.
273 | 
274 | 
275 | Maintainer tasks
276 | ================
277 | 
278 | Releases
279 | --------
280 | 
281 | .. todo:: This section assumes you are using PyPI to publicly release your package.
282 | 
283 | If instead you are using a different/private package index, please update
284 | the instructions accordingly.
285 | 
286 | If you are part of the group of maintainers and have correct user permissions
287 | on PyPI_, the following steps can be used to release a new version for
288 | ``FactorGo``:
289 | 
290 | #. Make sure all unit tests are successful.
291 | #. Tag the current commit on the main branch with a release tag, e.g., ``v1.2.3``.
292 | #. Push the new tag to the upstream repository_, e.g., ``git push upstream v1.2.3``
293 | #. Clean up the ``dist`` and ``build`` folders with ``tox -e clean``
294 | (or ``rm -rf dist build``)
295 | to avoid confusion with old builds and Sphinx docs.
296 | #. Run ``tox -e build`` and check that the files in ``dist`` have
297 | the correct version (no ``.dirty`` or git_ hash) according to the git_ tag.
298 | Also check the sizes of the distributions, if they are too big (e.g., >
299 | 500KB), unwanted clutter may have been accidentally included.
300 | #. Run ``tox -e publish -- --repository pypi`` and check that everything was
301 | uploaded to PyPI_ correctly.
302 | 
303 | 
304 | 
305 | .. [#contrib1] Even though these resources focus on open source projects and
306 | communities, the general ideas behind collaborating with other developers
307 | to collectively create software are general and can be applied to all sorts
308 | of environments, including private companies and proprietary code bases.
309 | 
310 | 
311 | .. <-- start -->
312 | .. todo:: Please review and change the following definitions:
313 | 
314 | ..
|the repository service| replace:: GitHub 315 | .. |contribute button| replace:: "Create pull request" 316 | 317 | .. _repository: https://github.com//FactorGo 318 | .. _issue tracker: https://github.com//FactorGo/issues 319 | .. <-- end --> 320 | 321 | 322 | .. |virtualenv| replace:: ``virtualenv`` 323 | .. |pre-commit| replace:: ``pre-commit`` 324 | .. |tox| replace:: ``tox`` 325 | 326 | 327 | .. _black: https://pypi.org/project/black/ 328 | .. _CommonMark: https://commonmark.org/ 329 | .. _contribution-guide.org: http://www.contribution-guide.org/ 330 | .. _creating a PR: https://docs.github.com/en/github/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request 331 | .. _descriptive commit message: https://chris.beams.io/posts/git-commit 332 | .. _docstrings: https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html 333 | .. _first-contributions tutorial: https://github.com/firstcontributions/first-contributions 334 | .. _flake8: https://flake8.pycqa.org/en/stable/ 335 | .. _git: https://git-scm.com 336 | .. _GitHub's fork and pull request workflow: https://guides.github.com/activities/forking/ 337 | .. _guide created by FreeCodeCamp: https://github.com/FreeCodeCamp/how-to-contribute-to-open-source 338 | .. _Miniconda: https://docs.conda.io/en/latest/miniconda.html 339 | .. _MyST: https://myst-parser.readthedocs.io/en/latest/syntax/syntax.html 340 | .. _other kinds of contributions: https://opensource.guide/how-to-contribute 341 | .. _pre-commit: https://pre-commit.com/ 342 | .. _PyPI: https://pypi.org/ 343 | .. _PyScaffold's contributor's guide: https://pyscaffold.org/en/stable/contributing.html 344 | .. _Pytest can drop you: https://docs.pytest.org/en/stable/usage.html#dropping-to-pdb-python-debugger-at-the-start-of-a-test 345 | .. _Python Software Foundation's Code of Conduct: https://www.python.org/psf/conduct/ 346 | .. 
_reStructuredText: https://www.sphinx-doc.org/en/master/usage/restructuredtext/ 347 | .. _Sphinx: https://www.sphinx-doc.org/en/master/ 348 | .. _tox: https://tox.readthedocs.io/en/stable/ 349 | .. _virtual environment: https://realpython.com/python-virtual-environments-a-primer/ 350 | .. _virtualenv: https://virtualenv.pypa.io/en/stable/ 351 | 352 | .. _GitHub web interface: https://docs.github.com/en/github/managing-files-in-a-repository/managing-files-on-github/editing-files-in-your-repository 353 | .. _GitHub's code editor: https://docs.github.com/en/github/managing-files-in-a-repository/managing-files-on-github/editing-files-in-your-repository 354 | -------------------------------------------------------------------------------- /src/factorgo/infer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import NamedTuple, Tuple, Union 3 | 4 | import jax.numpy as jnp 5 | import jax.numpy.linalg as jnpla 6 | import jax.scipy.special as scp 7 | from jax import jit, random 8 | 9 | 10 | @dataclass 11 | class Options: 12 | """simple class to store options for stopping rule (tolerance)""" 13 | 14 | elbo_tol: float = 1e-3 15 | max_iter: int = 10000 16 | 17 | 18 | class HyperParams: 19 | """simple class to store options for hyper-parameters""" 20 | 21 | halpha_a: float = 1e-5 22 | halpha_b: float = 1e-5 23 | htau_a: float = 1e-5 24 | htau_b: float = 1e-5 25 | hbeta: float = 1e-5 26 | 27 | 28 | class ZState(NamedTuple): 29 | """simple class to store results of posterior mean and variance for Z""" 30 | 31 | Z_m: jnp.ndarray 32 | Z_var: jnp.ndarray 33 | 34 | 35 | class WState(NamedTuple): 36 | """simple class to store results of posterior mean and variance for W""" 37 | 38 | W_m: jnp.ndarray 39 | W_var: jnp.ndarray 40 | 41 | 42 | class MuState(NamedTuple): 43 | """simple class to store results of posterior mean and variance (scalar) for Mu""" 44 | 45 | Mu_m: jnp.ndarray 46 | Mu_var: 
Union[float, jnp.ndarray] 47 | 48 | 49 | class AlphaState(NamedTuple): 50 | """simple class to store options for hyper-parameters and moments for Alpha""" 51 | 52 | phalpha_a: Union[float, jnp.ndarray] 53 | phalpha_b: jnp.ndarray 54 | Ealpha: jnp.ndarray 55 | Elog_alpha: jnp.ndarray 56 | 57 | 58 | class TauState(NamedTuple): 59 | """simple class to store options for hyper-parameters and moments for Tau""" 60 | 61 | phtau_a: Union[float, jnp.ndarray] 62 | phtau_b: Union[float, jnp.ndarray] 63 | Etau: Union[float, jnp.ndarray] 64 | Elog_tau: Union[float, jnp.ndarray] 65 | 66 | 67 | class JointState(NamedTuple): 68 | """simple class to store options for joint update of variables""" 69 | 70 | Z_m: jnp.ndarray 71 | Z_var: jnp.ndarray 72 | W_m: jnp.ndarray 73 | W_var: jnp.ndarray 74 | Mu_m: jnp.ndarray 75 | phalpha_b: jnp.ndarray 76 | Ealpha: jnp.ndarray 77 | Elog_alpha: jnp.ndarray 78 | 79 | 80 | class AuxState(NamedTuple): 81 | """simple class to store options for auxilary parameters""" 82 | 83 | b: jnp.ndarray 84 | R: jnp.ndarray 85 | R_inv: jnp.ndarray 86 | 87 | 88 | def get_init(key_init, k, dat, log, init_opt: str = "random") -> Tuple: 89 | """initialize matrix for inference 90 | only need to initialize parameters required for updating Z, mu and W. 
91 | quantities not used for updating those parameters are initialized with zeros 92 | as placeholder, with shape for jit to allocate memory 93 | """ 94 | 95 | n, p = dat.shape 96 | w_shape = (p, k) 97 | 98 | W_var_init = jnp.identity(k) 99 | 100 | if init_opt == "svd": 101 | U, D, Vh = jnpla.svd(dat, full_matrices=False) 102 | W_m_init = Vh[0:k, :].T 103 | log.info("Initialize W using tSVD.") 104 | elif init_opt == "random": 105 | key_init, key_w = random.split(key_init) 106 | W_m_init = random.normal(key_w, shape=w_shape) 107 | else: 108 | log.info(f"{init_opt} is not supported") 109 | exit() 110 | 111 | Mu_m = jnp.zeros((p,)) 112 | 113 | Ealpha_init = jnp.repeat(HyperParams.halpha_a / HyperParams.halpha_b, k) 114 | 115 | Etau = HyperParams.htau_a / HyperParams.htau_b 116 | 117 | # key quantities and placeholders 118 | wstate = WState(W_m_init, W_var_init) 119 | mustate = MuState(Mu_m, jnp.array([0.0])) 120 | alphastate = AlphaState( 121 | jnp.array([0.0]), jnp.zeros((k,)), Ealpha_init, jnp.zeros((k,)) 122 | ) 123 | taustate = TauState(jnp.array([0.0]), jnp.array([0.0]), Etau, jnp.array([0.0])) 124 | 125 | return wstate, mustate, alphastate, taustate 126 | 127 | 128 | def batched_trace(A): 129 | """calculate trace for each batched matrice""" 130 | # bii->b represents doing batched trace operations 131 | return jnp.einsum("bii->b", A) 132 | 133 | 134 | def calc_MeanQuadForm(wstate, WtW, zstate, mustate, B, sampleN, sampleN_sqrt): 135 | """calculate quadratic term E[(Zscore - fitted)^2]""" 136 | # import pdb; pdb.set_trace() 137 | W_m, _ = wstate 138 | Z_m, Z_var = zstate 139 | Mu_m, Mu_var = mustate 140 | p, _ = W_m.shape 141 | 142 | term1 = jnp.sum(B * B) 143 | term2 = jnp.sum(sampleN) * (p * Mu_var + Mu_m.T @ Mu_m) 144 | term3 = jnp.sum( 145 | sampleN 146 | * (batched_trace(WtW @ Z_var) + jnp.einsum("ni,ik,nk->n", Z_m, WtW, Z_m)) 147 | ) 148 | term4 = 2 * jnp.sum((Mu_m.T @ W_m) @ (sampleN[:, jnp.newaxis] * Z_m).T) 149 | term5 = 2 * jnp.trace(sampleN_sqrt[:, 
jnp.newaxis] * ((B @ W_m) @ Z_m.T)) 150 | term6 = 2 * jnp.sum((B @ Mu_m) * sampleN_sqrt) 151 | 152 | mean_quad_form = term1 + term2 + term3 + term4 - term5 - term6 153 | 154 | return mean_quad_form 155 | 156 | 157 | def logdet(M): 158 | """calculate log determinant for each batched matrice""" 159 | return jnpla.slogdet(M)[1] 160 | 161 | 162 | # Update Posterior Moments 163 | def pZ_main(B, wstate, EWtW, mustate, taustate, sampleN, sampleN_sqrt): 164 | """update posterior moments for factor score Z 165 | :pZ_m: (n,k) posterior moments 166 | :pZ_var: (n,k,k) posterior kxk covariance matrice for each study i 167 | """ 168 | W_m, _ = wstate 169 | Mu_m, _ = mustate 170 | _, _, Etau, _ = taustate 171 | 172 | n, p = B.shape 173 | _, k = W_m.shape 174 | 175 | pZ_var = jnpla.inv( 176 | ((Etau * EWtW)[:, :, jnp.newaxis] * sampleN).swapaxes(-1, 0) + jnp.eye(k) 177 | ) 178 | Bres = jnp.reshape((B / sampleN_sqrt[:, None] - Mu_m) * Etau, (n, p, 1)) 179 | pZ_m = (pZ_var @ (W_m.T @ Bres)).squeeze(-1) * sampleN[:, None] 180 | 181 | return ZState(pZ_m, pZ_var) 182 | 183 | 184 | def pMu_main(B, wstate, zstate, taustate, sampleN, sampleN_sqrt): 185 | """update posterior moments for intercept Mu 186 | :pMu_m: (p,) 187 | :pMu_var: a scalar (shared by all snps) 188 | """ 189 | W_m, _ = wstate 190 | Z_m, _ = zstate 191 | _, _, Etau, _ = taustate 192 | 193 | sum_N = jnp.sum(sampleN) 194 | pMu_var = 1 / (HyperParams.hbeta + Etau * sum_N) 195 | ZWt = Z_m @ W_m.T 196 | res_sum = jnp.sum(sampleN[:, None] * (B / sampleN_sqrt[:, None] - ZWt), axis=0) 197 | pMu_m = Etau * pMu_var * res_sum 198 | 199 | return MuState(pMu_m, pMu_var) 200 | 201 | 202 | def pW_main(B, zstate, mustate, taustate, alphastate, sampleN, sampleN_sqrt): 203 | """update posterior moments for factor loading W 204 | :pW_m: pxk 205 | :pW_V: kxk covariance matrice shared by all snps 206 | """ 207 | Z_m, Z_var = zstate 208 | Mu_m, _ = mustate 209 | _, _, Etau, _ = taustate 210 | _, _, Ealpha, _ = alphastate 211 | 212 | n, _ = 
Z_m.shape 213 | Bres = B / sampleN_sqrt[:, None] - Mu_m 214 | tmp = Z_var.T @ sampleN + (Z_m.T * sampleN) @ Z_m 215 | pW_V = jnp.linalg.inv(Etau * tmp + jnp.diag(Ealpha)) 216 | pW_m = jnp.einsum( 217 | "ik,np,nk->pi", Etau * pW_V, Bres, Z_m * sampleN[:, None], optimize="greedy" 218 | ) 219 | 220 | return WState(pW_m, pW_V) 221 | 222 | 223 | def palpha_main(WtW, p): 224 | """update posterior moments for ARD parameter alpha 225 | :phalpha_a: a scalar shared by all k latent factors 226 | :phalpha_b: (k,) 227 | """ 228 | 229 | phalpha_a = HyperParams.halpha_a + p * 0.5 230 | phalpha_b = HyperParams.halpha_b + 0.5 * jnp.diagonal(WtW) 231 | 232 | Ealpha = phalpha_a / phalpha_b 233 | Elog_alpha = scp.digamma(phalpha_a) - jnp.log(phalpha_b) 234 | 235 | return AlphaState(phalpha_a, phalpha_b, Ealpha, Elog_alpha) 236 | 237 | 238 | def get_aux(zstate, wstate, EWtW, taustate, sampleN): 239 | """find auxillary parameters for transformation method""" 240 | pZ_m, pZ_var = zstate 241 | pW_m, pW_var = wstate 242 | _, _, Etau, _ = taustate 243 | n, k = pZ_m.shape 244 | p, _ = pW_m.shape 245 | 246 | # 1) find b 247 | psi_n = Etau * p * pW_var 248 | # !! 
this can be simplified 249 | Psi = jnp.broadcast_to(psi_n[jnp.newaxis, ...], (n, k, k)) * sampleN.reshape( 250 | (n, 1, 1) 251 | ) + jnp.eye(k) 252 | Psi_Z = (Psi @ pZ_m.reshape((n, k, 1))).squeeze(-1) 253 | b = jnpla.inv(jnp.sum(Psi, axis=0)) @ jnp.sum(Psi_Z, axis=0) 254 | 255 | # 2) find R 256 | EZtZ = jnp.sum(pZ_var, axis=0) + pZ_m.T @ pZ_m 257 | # use jnpla.eigh() due to gpu end complains about jnpla.eig() 258 | # the result should be close subject to different ordering 259 | Lambda2, U = jnpla.eigh(EZtZ / n, symmetrize_input=False) 260 | U_weight = U * jnp.sqrt(Lambda2) 261 | 262 | quad_W = U_weight.T @ EWtW @ U_weight 263 | _, V = jnpla.eigh(quad_W, symmetrize_input=False) 264 | 265 | R = U_weight @ V 266 | R_inv = V.T / jnp.sqrt(Lambda2) @ U.T 267 | 268 | return AuxState(b, R, R_inv) 269 | 270 | 271 | def pjoint_main(zstate, wstate, EWtW, mustate, alphastate, auxstate): 272 | """jointly transform latent space""" 273 | pZ_m, pZ_var = zstate 274 | pW_m, pW_var = wstate 275 | pMu_m, pMu_var = mustate 276 | phalpha_a, phalpha_b, Ealpha, Elogalpha = alphastate 277 | b, R, R_inv = auxstate 278 | p, _ = pW_m.shape 279 | 280 | # 1) remove bias 281 | pZ_m_center = pZ_m - b 282 | pMu_m_center = pMu_m + pW_m @ b 283 | 284 | # 2) rotate: each row of pW_m (pxk) and each row of pZ_m (nxk) 285 | pW_m_rot = (R.T @ pW_m.T).T 286 | pW_var_rot = R.T @ pW_var @ R 287 | 288 | # rotate each of row of pZ_m (nxk) 289 | pZ_m_rot = (R_inv @ pZ_m_center.T).T 290 | pZ_var_rot = (R_inv @ pZ_var) @ R_inv.T 291 | 292 | EWtW_q = R.T @ EWtW @ R 293 | phalpha_b_rot = HyperParams.halpha_b + 0.5 * jnp.diag(EWtW_q) 294 | Ealpha_rot = phalpha_a / phalpha_b_rot 295 | Elog_alpha_rot = scp.digamma(phalpha_a.real) - jnp.log(phalpha_b_rot.real) 296 | 297 | zstate_new = ZState(pZ_m_rot, pZ_var_rot) 298 | wstate_new = WState(pW_m_rot, pW_var_rot) 299 | mustate_new = MuState(pMu_m_center, pMu_var) 300 | alphastate_new = AlphaState(phalpha_a, phalpha_b_rot, Ealpha_rot, Elog_alpha_rot) 301 | 302 | return 
wstate_new, zstate_new, mustate_new, alphastate_new 303 | 304 | 305 | def ptau_main(mean_quad, n, p): 306 | """update posterior moments for global scaling parameter Tau""" 307 | phtau_a = HyperParams.htau_a + n * p * 0.5 308 | phtau_b = 0.5 * mean_quad + HyperParams.htau_b 309 | 310 | Etau = phtau_a / phtau_b 311 | Elog_tau = scp.digamma(phtau_a) - jnp.log(phtau_b) 312 | 313 | return TauState(phtau_a, phtau_b, Etau, Elog_tau) 314 | 315 | 316 | # write function to call all updating function 317 | @jit 318 | def runVB( 319 | B, 320 | wstate_old, 321 | EWtW_old, 322 | mustate_old, 323 | alphastate_old, 324 | taustate_old, 325 | sampleN, 326 | sampleN_sqrt, 327 | ): 328 | """One updating step in the recursive loop""" 329 | n, p = B.shape 330 | 331 | zstate = pZ_main( 332 | B, wstate_old, EWtW_old, mustate_old, taustate_old, sampleN, sampleN_sqrt 333 | ) 334 | mustate = pMu_main(B, wstate_old, zstate, taustate_old, sampleN, sampleN_sqrt) 335 | wstate = pW_main( 336 | B, zstate, mustate, taustate_old, alphastate_old, sampleN, sampleN_sqrt 337 | ) 338 | EWtW = p * wstate.W_var + wstate.W_m.T @ wstate.W_m 339 | 340 | alphastate = palpha_main(EWtW, p) 341 | 342 | # find aux params b and R: 343 | auxstate = get_aux(zstate, wstate, EWtW, taustate_old, sampleN) 344 | wstate, zstate, mustate, alphastate = pjoint_main( 345 | zstate, wstate, EWtW, mustate, alphastate, auxstate 346 | ) 347 | 348 | EWtW = p * wstate.W_var + wstate.W_m.T @ wstate.W_m 349 | 350 | mean_quad = calc_MeanQuadForm( 351 | wstate, EWtW, zstate, mustate, B, sampleN, sampleN_sqrt 352 | ) 353 | taustate = ptau_main(mean_quad, n, p) 354 | 355 | return wstate, EWtW, zstate, mustate, alphastate, taustate, mean_quad 356 | 357 | 358 | # ELBO functions 359 | def KL_QW(W_m, W_var, Ealpha, Elog_alpha): 360 | """KL divergence between estimated posterior W and prior""" 361 | p, k = W_m.shape 362 | kl_qw = -0.5 * jnp.sum( 363 | logdet(W_var) 364 | + k 365 | + jnp.sum(Elog_alpha) 366 | - jnp.trace(Ealpha * W_var) 367 | - 
jnp.sum(W_m * Ealpha * W_m, axis=1) 368 | ) 369 | return kl_qw 370 | 371 | 372 | def KL_QZ(Z_m, Z_var): 373 | """KL divergence between estimated posterior Z and prior""" 374 | n, k = Z_m.shape 375 | kl_qz = 0.5 * jnp.sum( 376 | batched_trace(Z_var) + jnp.sum(Z_m * Z_m, axis=1) - k - logdet(Z_var) 377 | ) 378 | return kl_qz 379 | 380 | 381 | def KL_QMu(Mu_m, Mu_var): 382 | """KL divergence between estimated posterior Mu and prior""" 383 | p = Mu_m.size 384 | kl_qmu = 0.5 * ( 385 | jnp.sum(HyperParams.hbeta * Mu_var) 386 | + HyperParams.hbeta * (Mu_m.T @ Mu_m) 387 | - p 388 | - p * jnp.log(HyperParams.hbeta) 389 | - jnp.sum(jnp.log(Mu_var)) 390 | ) 391 | return kl_qmu 392 | 393 | 394 | def KL_gamma(pa, pb, ha, hb): 395 | """KL divergence between two gamma distributions""" 396 | kl_gamma = ( 397 | (pa - ha) * scp.digamma(pa) 398 | - scp.gammaln(pa) 399 | + scp.gammaln(ha) 400 | + ha * (jnp.log(pb) - jnp.log(hb)) 401 | + pa * ((hb - pb) / pb) 402 | ) 403 | return kl_gamma 404 | 405 | 406 | def KL_Qalpha(pa, pb): 407 | """KL divergence between estimated posterior Alpha and prior""" 408 | kl_qa = jnp.sum(KL_gamma(pa, pb, HyperParams.halpha_a, HyperParams.halpha_b)) 409 | return kl_qa 410 | 411 | 412 | def KL_Qtau(pa, pb): 413 | """KL divergence between estimated posterior Tau and prior""" 414 | kl_qtau = KL_gamma(pa, pb, HyperParams.htau_a, HyperParams.htau_b) 415 | return kl_qtau 416 | 417 | 418 | @jit 419 | def elbo(B, wstate, zstate, mustate, alphastate, taustate, mean_quad): 420 | """Calculate ELBO""" 421 | W_m, W_var = wstate 422 | Z_m, Z_var = zstate 423 | Mu_m, Mu_var = mustate 424 | phtau_a, phtau_b, Etau, Elog_tau = taustate 425 | phalpha_a, phalpha_b, Ealpha, Elog_alpha = alphastate 426 | n, p = B.shape 427 | 428 | pD = 0.5 * (n * p * Elog_tau - Etau * mean_quad) 429 | 430 | kl_qw = KL_QW(W_m, W_var, Ealpha, Elog_alpha) 431 | kl_qz = KL_QZ(Z_m, Z_var) 432 | kl_qmu = KL_QMu(Mu_m, Mu_var) 433 | kl_qa = KL_Qalpha(phalpha_a, phalpha_b) 434 | kl_qt = 
KL_Qtau(phtau_a, phtau_b) 435 | elbo_sum = pD - (kl_qw + kl_qz + kl_qmu + kl_qa + kl_qt) 436 | 437 | return ( 438 | elbo_sum.real, 439 | pD.real, 440 | kl_qw.real, 441 | kl_qz.real, 442 | kl_qmu.real, 443 | kl_qa.real, 444 | kl_qt.real, 445 | ) 446 | 447 | 448 | # calculate R2 for ordered factors: exausted memory 449 | def R2(B, W_m, Z_m, Etau, sampleN_sqrt): 450 | """Calculate variance explained by each inferred factor 451 | Use residuals to calculate this. 452 | """ 453 | 454 | n, p = B.shape 455 | _, k = Z_m.shape 456 | 457 | tss = jnp.sum(B * B) * Etau 458 | 459 | sse = jnp.zeros((k,)) 460 | for i in range(n): 461 | WZ = W_m * Z_m[i] * sampleN_sqrt[i] # pxk 462 | res = B[i][:, None] - WZ # pxk 463 | tmp = jnp.sum(res * res, axis=0) # (k,) 464 | sse += tmp 465 | r2 = 1.0 - sse * Etau / tss 466 | 467 | return r2 468 | 469 | 470 | @jit 471 | def _inner_fit( 472 | B, 473 | wstate_old, 474 | EWtW_old, 475 | mustate_old, 476 | alphastate_old, 477 | taustate_old, 478 | sampleN, 479 | sampleN_sqrt, 480 | ): 481 | """update parameters and calculate ELBO""" 482 | wstate, EWtW, zstate, mustate, alphastate, taustate, mean_quad = runVB( 483 | B, 484 | wstate_old, 485 | EWtW_old, 486 | mustate_old, 487 | alphastate_old, 488 | taustate_old, 489 | sampleN, 490 | sampleN_sqrt, 491 | ) 492 | 493 | check_elbo, pD, kl_qw, kl_qz, kl_qmu, kl_qa, kl_qt = elbo( 494 | B, wstate, zstate, mustate, alphastate, taustate, mean_quad 495 | ) 496 | 497 | return wstate, EWtW, zstate, mustate, alphastate, taustate, check_elbo 498 | 499 | 500 | def fit(B, args, k, key_init, log, options, sampleN, sampleN_sqrt): 501 | """Wrapper function for running factorgo""" 502 | # set initializers 503 | _, p_snps = B.shape 504 | 505 | log.info(f"Initializing mean parameters with seed {args.seed}.") 506 | wstate, mustate, alphastate, taustate = get_init( 507 | key_init, k, B, log, args.init_factor 508 | ) 509 | EWtW = p_snps * wstate.W_var + wstate.W_m.T @ wstate.W_m 510 | 511 | log.info("Completed 
initialization.") 512 | f_finfo = jnp.finfo(float) # Machine limits for floating point types 513 | oelbo, delbo = f_finfo.min, f_finfo.max 514 | 515 | log.info("Starting Variational inference.") 516 | log.info("first iter may be slow due to JIT compilation).") 517 | RATE = args.rate # print per RATE iterations 518 | for idx in range(options.max_iter): 519 | 520 | wstate, EWtW, zstate, mustate, alphastate, taustate, check_elbo = _inner_fit( 521 | B, wstate, EWtW, mustate, alphastate, taustate, sampleN, sampleN_sqrt 522 | ) 523 | 524 | delbo = check_elbo - oelbo 525 | oelbo = check_elbo 526 | if idx % RATE == 0: 527 | log.info( 528 | f"itr={idx} | Elbo = {check_elbo} | deltaElbo = {delbo} | Tau = {taustate.Etau}" 529 | ) 530 | 531 | if jnp.fabs(delbo) < args.elbo_tol: 532 | break 533 | 534 | r2 = R2(B, wstate.W_m, zstate.Z_m, taustate.Etau, sampleN_sqrt) 535 | f_order = jnp.argsort(-jnp.abs(r2)) # get index for sorting 536 | ordered_r2 = r2[f_order] 537 | ordered_Z_m = zstate.Z_m[:, f_order] 538 | ordered_W_m = wstate.W_m[:, f_order] 539 | 540 | f_info = jnp.column_stack( 541 | (jnp.arange(k) + 1, alphastate.Ealpha.real[f_order], ordered_r2) 542 | ) 543 | log.info(f"Finished inference after {idx} iterations.") 544 | log.info(f"Final elbo = {check_elbo} and resid precision Tau = {taustate.Etau}") 545 | log.info(f"Final sorted Ealpha = {jnp.sort(alphastate.Ealpha)}") 546 | log.info(f"Final sorted R2 = {ordered_r2}") 547 | 548 | return ( 549 | wstate.W_m, 550 | wstate.W_var, 551 | zstate.Z_m, 552 | zstate.Z_var, 553 | f_info, 554 | f_order, 555 | ordered_W_m, 556 | ordered_Z_m, 557 | ) 558 | -------------------------------------------------------------------------------- /example/runFactorGo.old.py: -------------------------------------------------------------------------------- 1 | #! 
#!/usr/bin/env python
"""Legacy FactorGo script: variational Bayesian matrix factorization of
GWAS summary statistics.

The model factorizes a study-by-SNP Z-score matrix B as
``B ~ sqrt(N) * (Z @ W.T + Mu) + noise`` with automatic-relevance-determination
(ARD) priors on the columns of W.  Inference is closed-form coordinate
ascent on the evidence lower bound (ELBO), accelerated with JAX.
"""
import argparse as ap
import logging
import os
import sys
from dataclasses import dataclass
from typing import NamedTuple, Union

import numpy as np
import pandas as pd

import jax
import jax.numpy as jnp
import jax.numpy.linalg as jnpla
import jax.scipy.special as scp
from jax import jit, random

# disable jax preallocation or set the % of memory for preallocation
# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']='false'
# os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION']='.10'


def get_logger(name, path=None):
    """Return a logger writing to stderr and, optionally, to ``<path>.log``.

    :param str name: logger name (typically ``__name__``)
    :param path: output prefix for an on-disk log file; ``None`` disables it
    :return: configured ``logging.Logger`` (idempotent across repeat calls)
    """
    logger = logging.getLogger(name)
    if not logger.handlers:
        # Prevent logging from propagating to the root logger
        logger.propagate = False
        console = logging.StreamHandler()
        logger.addHandler(console)

        log_format = "[%(asctime)s - %(levelname)s] %(message)s"
        date_format = "%Y-%m-%d %H:%M:%S"
        formatter = logging.Formatter(fmt=log_format, datefmt=date_format)
        console.setFormatter(formatter)

        if path is not None:
            # FileHandler owns (and eventually closes) the underlying file;
            # the previous open() + StreamHandler combination leaked the handle
            disk_handler = logging.FileHandler("{}.log".format(path), mode="w")
            disk_handler.setFormatter(formatter)
            logger.addHandler(disk_handler)

    return logger


## set platform for jax
def set_platform(platform=None):
    """
    Changes platform to CPU, GPU, or TPU. This utility only takes
    effect at the beginning of your program.

    :param str platform: either 'cpu', 'gpu', or 'tpu'.
    """
    if platform is None:
        platform = os.getenv("JAX_PLATFORM_NAME", "cpu")
    jax.config.update("jax_platform_name", platform)
    return


@dataclass
class Options:
    """simple class to store options for stopping rule"""

    # tolerances for the change in ELBO / residual precision, and max iters
    elbo_tol: float = 1e-3
    tau_tol: float = 1e-3
    max_iter: int = 1000


class HyperParams:
    """simple class to store options for hyper-parameters

    NOTE: these are class attributes mutated once in main() when the user
    supplies --hyper; the KL/update functions read them globally.
    """

    halpha_a: float = 1e-5  # Gamma prior shape for ARD precisions alpha
    halpha_b: float = 1e-5  # Gamma prior rate for ARD precisions alpha
    htau_a: float = 1e-5  # Gamma prior shape for residual precision tau
    htau_b: float = 1e-5  # Gamma prior rate for residual precision tau
    hbeta: float = 1e-5  # prior precision on the intercept Mu


class InitParams(NamedTuple):
    """simple class to store options for initialization of matrix"""

    W_m: jnp.ndarray
    W_var: jnp.ndarray
    Mu_m: jnp.ndarray
    Ealpha: jnp.ndarray  # 1d
    Etau: Union[float, jnp.ndarray]  # for scalars


class ZState(NamedTuple):
    """posterior moments of the latent factor scores Z"""

    Z_m: jnp.ndarray
    Z_var: jnp.ndarray


class WState(NamedTuple):
    """posterior moments of the loadings W"""

    W_m: jnp.ndarray
    W_var: Union[float, jnp.ndarray]


class MuState(NamedTuple):
    """posterior moments of the intercept Mu"""

    Mu_m: jnp.ndarray
    Mu_var: jnp.ndarray


class AlphaState(NamedTuple):
    """posterior Gamma parameters and expectations for alpha"""

    phalpha_a: jnp.ndarray
    phalpha_b: jnp.ndarray
    Ealpha: jnp.ndarray
    Elog_alpha: jnp.ndarray


class TauState(NamedTuple):
    """posterior Gamma parameters and expectations for tau"""

    phtau_a: jnp.ndarray
    phtau_b: jnp.ndarray
    Etau: Union[float, jnp.ndarray]
    Elog_tau: Union[float, jnp.ndarray]


class JointState(NamedTuple):
    """results of the joint (de-bias + rotation) update of variables"""

    Z_m: jnp.ndarray
    Z_var: jnp.ndarray
    W_m: jnp.ndarray
    W_var: jnp.ndarray
    Mu_m: jnp.ndarray
    phalpha_b: jnp.ndarray
    Ealpha: jnp.ndarray
    Elog_alpha: jnp.ndarray


def read_data(z_path, N_path, log, removeN=False, scaledat=True):
    """Load the Z-score matrix and per-study sample sizes.

    :param z_path: TSV whose headers are ["snp", "trait1", ..., "traitn"]
    :param N_path: TSV with one (headered) column of sample sizes, in the
        same trait order as ``z_path``
    :param log: logger
    :param removeN: drop N from the model by forcing all N == 1
    :param scaledat: center and scale each SNP across traits
    :return: (B as an (n, p) jax array, sampleN (n,), sqrt(sampleN) (n,))
    """
    # Read dataset (read as data frame)
    df_z = pd.read_csv(z_path, delimiter="\t", header=0)
    snp_col = df_z.columns[0]

    # drop the SNP id column (axis=1) and transpose to n studies x p SNPs
    df_z.drop(labels=[snp_col], axis=1, inplace=True)
    df_z = df_z.astype("float").T

    if scaledat:
        df_z = df_z.subtract(df_z.mean())
        df_z = df_z.divide(df_z.std())
        log.info("Scale SNPs to mean zero and sd 1")

    # convert to numpy/jax device-array (n,p)
    df_z = jnp.array(df_z)

    # read sample size file (single headered column) and convert to an array
    df_N = pd.read_csv(N_path, delimiter="\t", header=0)
    df_N = df_N.astype("float")
    N_col = df_N.columns[0]
    sampleN = df_N[N_col].values
    sampleN_sqrt = jnp.sqrt(sampleN)

    if removeN:
        n, _ = df_z.shape
        sampleN = jnp.ones((n,))
        sampleN_sqrt = jnp.ones((n,))
        log.info("Remove N from model, set all N == 1.")

    return df_z, sampleN, sampleN_sqrt


def get_init(key_init, n, p, k, dat, log, init_opt="random"):
    """Initialize the variational parameters needed for the first updates.

    Only W (plus the scalar prior expectations) needs initialization since
    the coordinate ascent updates Z first from W.

    :param key_init: jax PRNG key
    :param n: number of studies (kept for interface stability; unused here)
    :param p: number of SNPs
    :param k: number of latent factors
    :param dat: (n, p) data matrix, used only for the SVD initializer
    :param init_opt: "svd" uses a truncated SVD of ``dat``; any other value
        (e.g. "random" / "zero") falls back to random normal W
    """
    w_shape = (p, k)
    W_var_init = jnp.identity(k)

    if init_opt == "svd":
        _, _, Vh = jnpla.svd(dat, full_matrices=False)
        W_m_init = Vh[0:k, :].T
        log.info("Initialize W and Z using tsvd.")
    else:
        key_init, key_w = random.split(key_init)
        W_m_init = random.normal(key_w, shape=w_shape)

    Mu_m = jnp.zeros((p,))

    # prior means of the precision parameters
    Ealpha_init = jnp.repeat(HyperParams.halpha_a / HyperParams.halpha_b, k)
    Etau = HyperParams.htau_a / HyperParams.htau_b

    return InitParams(
        W_m=W_m_init,
        W_var=W_var_init,
        Mu_m=Mu_m,
        Ealpha=Ealpha_init,
        Etau=Etau,
    )


def batched_outer(A, B):
    """Outer product of corresponding rows of A and B: (b,i),(b,k)->(b,i,k)."""
    return jnp.einsum("bi,bk->bik", A, B)


def batched_inner(A, B):
    """Batched Frobenius inner product: (b,i,j),(b,i,j)->(b,)."""
    return jnp.einsum("bij,bij->b", A, B)


def batched_trace(A):
    """Trace of each matrix in a batch: (b,i,i)->(b,)."""
    return jnp.einsum("bii->b", A)


def batched_broadcast(A, B):
    """Scale each row of B by the matching element of 1d array A."""
    return jnp.einsum("i,ij->ij", A, B)


def calc_MeanQuadForm(W_m, WtW, Z_m, Z_var, Mu_m, Mu_var, B, sampleN, sampleN_sqrt):
    """E[ sum_i N_i * || B_i/sqrt(N_i) - W z_i - Mu ||^2 ] under q.

    Expanding the square gives the six terms below; the two cross terms
    involving B enter with a minus sign.
    """
    p, _ = W_m.shape
    term1 = jnp.sum(B * B)
    term2 = jnp.sum(sampleN) * (p * Mu_var + Mu_m.T @ Mu_m)
    term3 = jnp.sum(
        sampleN
        * (batched_trace(WtW @ Z_var) + jnp.einsum("ni,ik,nk->n", Z_m, WtW, Z_m))
    )
    term4 = 2 * jnp.sum((Mu_m.T @ W_m) @ (sampleN[:, jnp.newaxis] * Z_m).T)
    term5 = 2 * jnp.trace(sampleN_sqrt[:, jnp.newaxis] * ((B @ W_m) @ Z_m.T))
    term6 = 2 * jnp.sum((B @ Mu_m) * sampleN_sqrt)

    mean_quad_form = term1 + term2 + term3 + term4 - term5 - term6

    return mean_quad_form


def logdet(M):
    """
    calculate logdet for each batched matrice
    """
    return jnpla.slogdet(M)[1]


## Update Posterior Moments
def pZ_main(B, W_m, EWtW, Mu_m, Etau, sampleN, sampleN_sqrt):
    """
    :pZ_m: (n,k) posterior moments
    :pZ_var: (n,k,k) posterior kxk covariance matrice for each study i
    """
    n, p = B.shape
    _, k = W_m.shape
    # each per-study precision differs only by the factor N_i
    pZ_var = jnpla.inv(
        ((Etau * EWtW)[:, :, jnp.newaxis] * sampleN).swapaxes(-1, 0) + jnp.eye(k)
    )
    Bres = jnp.reshape((B / sampleN_sqrt[:, None] - Mu_m) * Etau, (n, p, 1))
    pZ_m = (pZ_var @ (W_m.T @ Bres)).squeeze(-1) * sampleN[:, None]

    return ZState(pZ_m, pZ_var)


def pMu_main(B, W_m, Z_m, Etau, sampleN, sampleN_sqrt):
    """
    :pMu_m: (p,)
    :pMu_var: a scalar (shared by all snps)
    """
    sum_N = jnp.sum(sampleN)
    pMu_var = 1 / (HyperParams.hbeta + Etau * sum_N)
    ZWt = Z_m @ W_m.T
    res_sum = jnp.sum(sampleN[:, None] * (B / sampleN_sqrt[:, None] - ZWt), axis=0)
    pMu_m = Etau * pMu_var * res_sum

    return MuState(pMu_m, pMu_var)


def pW_main(B, Z_m, Z_var, Mu_m, Etau, Ealpha, sampleN, sampleN_sqrt):
    """
    :pW_m: pxk
    :pW_V: kxk covariance matrice shared by all snps
    """
    Bres = B / sampleN_sqrt[:, None] - Mu_m
    # E[Z' diag(N) Z] = sum_i N_i (Z_var_i + z_i z_i')
    tmp = Z_var.T @ sampleN + (Z_m.T * sampleN) @ Z_m
    pW_V = jnp.linalg.inv(Etau * tmp + jnp.diag(Ealpha))
    pW_m = jnp.einsum(
        "ik,np,nk->pi", Etau * pW_V, Bres, Z_m * sampleN[:, None], optimize="greedy"
    )

    return WState(pW_m, pW_V)


def palpha_main(WtW, p):
    """
    :phalpha_a: shared by all k latent factirs
    :phalpha_b: (k,)
    """
    phalpha_a = HyperParams.halpha_a + p * 0.5
    phalpha_b = HyperParams.halpha_b + 0.5 * jnp.diagonal(WtW)

    Ealpha = phalpha_a / phalpha_b
    Elog_alpha = scp.digamma(phalpha_a) - jnp.log(phalpha_b)

    return AlphaState(phalpha_a, phalpha_b, Ealpha, Elog_alpha)


def get_aux(pZ_m, pZ_var, pW_m, pW_var, EWtW, Etau, sampleN):
    """Compute auxiliary bias b and rotation R (plus its inverse) that
    re-center Z and diagonalize the latent space (identifiability step)."""
    n, k = pZ_m.shape
    p, _ = pW_m.shape

    ## 1) find b
    psi_n = Etau * p * pW_var
    # !! this can be simplified
    Psi = jnp.broadcast_to(psi_n[jnp.newaxis, ...], (n, k, k)) * sampleN.reshape(
        (n, 1, 1)
    ) + jnp.eye(k)
    Psi_Z = (Psi @ pZ_m.reshape((n, k, 1))).squeeze(-1)
    b = jnpla.inv(jnp.sum(Psi, axis=0)) @ jnp.sum(Psi_Z, axis=0)

    ## 2) find R
    EZtZ = jnp.sum(pZ_var, axis=0) + pZ_m.T @ pZ_m
    ## use jnpla.eigh() due to gpu end complains about jnpla.eig()
    ## the result should be close subject to different ordering
    Lambda2, U = jnpla.eigh(EZtZ / n, symmetrize_input=False)
    U_weight = U * jnp.sqrt(Lambda2)

    quad_W = U_weight.T @ EWtW @ U_weight
    _, V = jnpla.eigh(quad_W, symmetrize_input=False)

    R = U_weight @ V
    R_inv = V.T / jnp.sqrt(Lambda2) @ U.T

    return b, R, R_inv


def pjoint_main(
    pZ_m, pZ_var, pW_m, pW_var, EWtW, pMu_m, phalpha_a, Etau, sampleN, b, R, R_inv
):
    """Jointly transform the latent space: remove bias b, then rotate W/Z
    by R / R^{-1}, and refresh the alpha posterior for the rotated W."""
    n, k = pZ_m.shape
    p, _ = pW_m.shape

    ## 1) remove bias
    pZ_m_center = pZ_m - b
    pMu_m_center = pMu_m + pW_m @ b

    ## 2) rotate: each row of pW_m (pxk) and each row of pZ_m (nxk)
    pW_m_rot = (R.T @ pW_m.T).T
    pW_var_rot = R.T @ pW_var @ R

    # rotate each of row of pZ_m (nxk)
    pZ_m_rot = (R_inv @ pZ_m_center.T).T
    pZ_var_rot = (R_inv @ pZ_var) @ R_inv.T

    EWtW_q = R.T @ EWtW @ R
    phalpha_b_rot = HyperParams.halpha_b + 0.5 * jnp.diag(EWtW_q)
    Ealpha_rot = phalpha_a / phalpha_b_rot
    Elog_alpha_rot = scp.digamma(phalpha_a.real) - jnp.log(phalpha_b_rot.real)

    return JointState(
        pZ_m_rot,
        pZ_var_rot,
        pW_m_rot,
        pW_var_rot,
        pMu_m_center,
        phalpha_b_rot,
        Ealpha_rot,
        Elog_alpha_rot,
    )


def ptau_main(mean_quad, n, p):
    """Posterior Gamma update for the residual precision tau."""
    phtau_a = HyperParams.htau_a + n * p * 0.5
    phtau_b = 0.5 * mean_quad + HyperParams.htau_b

    Etau = phtau_a / phtau_b
    Elog_tau = scp.digamma(phtau_a) - jnp.log(phtau_b)

    return TauState(phtau_a, phtau_b, Etau, Elog_tau)


## one full round of coordinate-ascent updates, jit-compiled as a unit
@jit
def runVB(B, W_m, W_var, EWtW, Mu_m, Ealpha, Etau, sampleN, sampleN_sqrt, n, p):
    """Run one VB iteration: update Z, Mu, W, alpha, the joint de-bias /
    rotation, and tau; return all refreshed moments plus the quad form."""
    Z_m, Z_var = pZ_main(B, W_m, EWtW, Mu_m, Etau, sampleN, sampleN_sqrt)
    Mu_m, Mu_var = pMu_main(B, W_m, Z_m, Etau, sampleN, sampleN_sqrt)
    W_m, W_var = pW_main(B, Z_m, Z_var, Mu_m, Etau, Ealpha, sampleN, sampleN_sqrt)
    EWtW = p * W_var + W_m.T @ W_m

    phalpha_a, phalpha_b, Ealpha, Elog_alpha = palpha_main(EWtW, p)

    # find aux params b and R:
    b, R, R_inv = get_aux(Z_m, Z_var, W_m, W_var, EWtW, Etau, sampleN)
    Z_m, Z_var, W_m, W_var, Mu_m, phalpha_b, Ealpha, Elog_alpha = pjoint_main(
        Z_m, Z_var, W_m, W_var, EWtW, Mu_m, phalpha_a, Etau, sampleN, b, R, R_inv
    )
    EWtW = p * W_var + W_m.T @ W_m

    mean_quad = calc_MeanQuadForm(
        W_m, EWtW, Z_m, Z_var, Mu_m, Mu_var, B, sampleN, sampleN_sqrt
    )
    phtau_a, phtau_b, Etau, Elog_tau = ptau_main(mean_quad, n, p)

    return (
        W_m,
        W_var,
        EWtW,
        Z_m,
        Z_var,
        Mu_m,
        Mu_var,
        phtau_a,
        phtau_b,
        phalpha_a,
        phalpha_b,
        Ealpha,
        Elog_alpha,
        Etau,
        Elog_tau,
        mean_quad,
    )


## ELBO functions
def KL_QW(W_m, W_var, Ealpha, Elog_alpha):
    """KL( q(W) || p(W | alpha) ) with shared kxk covariance W_var."""
    p, k = W_m.shape
    kl_qw = -0.5 * jnp.sum(
        logdet(W_var)
        + k
        + jnp.sum(Elog_alpha)
        - jnp.trace(Ealpha * W_var)
        - jnp.sum(W_m * Ealpha * W_m, axis=1)
    )
    return kl_qw


def KL_QZ(Z_m, Z_var):
    """KL( q(Z) || p(Z) ) against a standard normal prior on each z_i."""
    n, k = Z_m.shape
    kl_qz = 0.5 * jnp.sum(
        batched_trace(Z_var) + jnp.sum(Z_m * Z_m, axis=1) - k - logdet(Z_var)
    )
    return kl_qz


def KL_QMu(Mu_m, Mu_var):
    """KL( q(Mu) || p(Mu) ) where q(mu_j) = N(Mu_m[j], Mu_var) shares one
    scalar variance across all p SNPs and p(mu_j) = N(0, 1/hbeta).

    BUGFIX: the scalar variance contributes once per coordinate, so both
    the hbeta*Mu_var and log(Mu_var) terms carry a factor of p (the old
    jnp.sum(...) over the scalar counted them only once).
    """
    p = Mu_m.size
    kl_qmu = 0.5 * (
        p * HyperParams.hbeta * Mu_var
        + HyperParams.hbeta * (Mu_m.T @ Mu_m)
        - p
        - p * jnp.log(HyperParams.hbeta)
        - p * jnp.log(Mu_var)
    )
    return kl_qmu


def KL_gamma(pa, pb, ha, hb):
    """KL( Gamma(pa, pb) || Gamma(ha, hb) ), elementwise."""
    kl_gamma = (
        (pa - ha) * scp.digamma(pa)
        - scp.gammaln(pa)
        + scp.gammaln(ha)
        + ha * (jnp.log(pb) - jnp.log(hb))
        + pa * ((hb - pb) / pb)
    )
    return kl_gamma


def KL_Qalpha(pa, pb):
    """Summed KL for the k ARD precisions alpha."""
    kl_qa = jnp.sum(KL_gamma(pa, pb, HyperParams.halpha_a, HyperParams.halpha_b))
    return kl_qa


def KL_Qtau(pa, pb):
    """KL for the residual precision tau."""
    kl_qtau = KL_gamma(pa, pb, HyperParams.htau_a, HyperParams.htau_b)
    return kl_qtau


@jit
def elbo(
    W_m,
    W_var,
    Z_m,
    Z_var,
    Mu_m,
    Mu_var,
    phtau_a,
    phtau_b,
    phalpha_a,
    phalpha_b,
    Ealpha,
    Elog_alpha,
    Etau,
    Elog_tau,
    B,
    mean_quad,
):
    """Evidence lower bound: expected log-likelihood minus all KL terms.

    Returns the total plus each component (all .real) for logging.
    """
    n, p = B.shape

    pD = 0.5 * (n * p * Elog_tau - Etau * mean_quad)

    kl_qw = KL_QW(W_m, W_var, Ealpha, Elog_alpha)
    kl_qz = KL_QZ(Z_m, Z_var)
    kl_qmu = KL_QMu(Mu_m, Mu_var)
    kl_qa = KL_Qalpha(phalpha_a, phalpha_b)
    kl_qt = KL_Qtau(phtau_a, phtau_b)
    elbo_sum = pD - (kl_qw + kl_qz + kl_qmu + kl_qa + kl_qt)

    return (
        elbo_sum.real,
        pD.real,
        kl_qw.real,
        kl_qz.real,
        kl_qmu.real,
        kl_qa.real,
        kl_qt.real,
    )


def R2(B, W_m, Z_m, Etau, sampleN_sqrt):
    """Per-factor variance explained (1 - SSE/TSS), looping over studies
    to keep memory bounded instead of materializing an (n, p, k) residual."""
    n, p = B.shape
    _, k = Z_m.shape

    tss = jnp.sum(B * B) * Etau

    sse = jnp.zeros((k,))
    for i in range(n):
        WZ = W_m * Z_m[i] * sampleN_sqrt[i]  # pxk
        res = B[i][:, None] - WZ  # pxk
        tmp = jnp.sum(res * res, axis=0)  # (k,)
        sse += tmp
    r2 = 1.0 - sse * Etau / tss

    return r2


def main(args):
    """CLI entry point: parse args, run VB to convergence, write results."""
    argp = ap.ArgumentParser(description="")
    argp.add_argument("Zscore_path")
    argp.add_argument("N_path")
    # single-dash options must be one letter ("-k", not "-knum")
    argp.add_argument("-k", type=int, default=10)
    argp.add_argument(
        "--elbo-tol",
        default=1e-3,
        type=float,
        help="Tolerance for change in ELBO to halt inference",
    )
    argp.add_argument(
        "--tau-tol",
        default=1e-6,
        type=float,
        help="Tolerance for change in residual variance to halt inference",
    )
    argp.add_argument(
        "--hyper",
        default=None,
        nargs="+",
        type=float,
        help="Input hyperparameter in order for alpha, tau, and beta; Default hyperparameters are 1e-3",
    )
    argp.add_argument(
        "--max-iter",
        default=10000,
        type=int,
        help="Maximum number of iterations to learn parameters",
    )
    argp.add_argument(
        "--init-factor",
        choices=["random", "svd", "zero"],
        default="random",
        help="How to initialize the latent factors and weights",
    )
    argp.add_argument(
        "--removeN",
        action="store_true",
        help="remove scalar N from model, i.e. set all N==1",
        default=False,
    )
    argp.add_argument(
        "--scaledat",
        action="store_true",
        default=True,
        help="scale each SNPs effect across traits (Default=True)",
    )
    # BUGFIX: --scaledat is store_true with default=True, so scaling could
    # never actually be disabled; provide an explicit opt-out flag.
    argp.add_argument(
        "--no-scaledat",
        action="store_false",
        dest="scaledat",
        help="disable scaling of each SNPs effect across traits",
    )
    argp.add_argument(
        "--rate",
        default=250,
        type=int,
        help="Rate of printing elbo info; default is printing per 250 iters",
    )
    argp.add_argument("-p", "--platform", choices=["cpu", "gpu"], default="cpu")
    argp.add_argument(
        "-s", "--seed", type=int, default=123456789, help="Seed for randomization."
    )
    argp.add_argument("-d", "--debug", action="store_true", default=False)
    argp.add_argument("-v", "--verbose", action="store_true", default=False)
    argp.add_argument(
        "-o", "--output", type=str, default="VBres", help="Prefix path for output"
    )

    args = argp.parse_args(args)  # a list of strings

    log = get_logger(__name__, args.output)
    if args.verbose:
        log.setLevel(logging.DEBUG)
    else:
        log.setLevel(logging.INFO)

    # setup to use either CPU (default) or GPU
    set_platform(args.platform)

    # ensure 64bit precision (default use 32bit)
    jax.config.update("jax_enable_x64", True)

    # init key (for jax)
    key = random.PRNGKey(args.seed)
    key, key_init = random.split(key, 2)  # split into 2 chunk

    log.info("Loading GWAS effect size and standard error.")
    B, sampleN, sampleN_sqrt = read_data(
        args.Zscore_path, args.N_path, log, args.removeN, args.scaledat
    )
    log.info("Finished loading GWAS effect size, sample size and standard error.")

    n_studies, p_snps = B.shape
    log.info(f"Found N = {n_studies} studies, P = {p_snps} SNPs")

    # number of factors
    k = args.k
    log.info(f"User set K = {k} latent factors.")

    # set options for stopping rule
    options = Options(args.elbo_tol, args.tau_tol, args.max_iter)

    # set 5 hyperparameters: otherwise use defaults
    if args.hyper is not None:
        # fail fast with a clear message instead of an IndexError below
        if len(args.hyper) != 5:
            argp.error("--hyper requires exactly 5 values")
        HyperParams.halpha_a = float(args.hyper[0])
        HyperParams.halpha_b = float(args.hyper[1])
        HyperParams.htau_a = float(args.hyper[2])
        HyperParams.htau_b = float(args.hyper[3])
        HyperParams.hbeta = float(args.hyper[4])
        log.info(
            f"set parameters {HyperParams.halpha_a},{HyperParams.halpha_b},{HyperParams.htau_a},{HyperParams.htau_b}, {HyperParams.hbeta} "
        )

    # set initializers
    log.info(f"Initalizing mean parameters with seed {args.seed}.")
    (W_m, W_var, Mu_m, Ealpha, Etau) = get_init(
        key_init, n_studies, p_snps, k, B, log, args.init_factor
    )
    EWtW = p_snps * W_var + W_m.T @ W_m
    phalpha_a = HyperParams.halpha_a
    phalpha_b = HyperParams.halpha_b
    log.info("Completed initalization.")

    f_finfo = jnp.finfo(float)  # Machine limits for floating point types
    oelbo, delbo = f_finfo.min, f_finfo.max
    otau, dtau = 1000, 1000  # initial value for delta tau

    log.info(
        "Starting Variational inference (first iter may be slow due to JIT compilation)."
    )
    RATE = args.rate  # print per 250 iterations
    for idx in range(options.max_iter):

        (
            W_m,
            W_var,
            EWtW,
            Z_m,
            Z_var,
            Mu_m,
            Mu_var,
            phtau_a,
            phtau_b,
            phalpha_a,
            phalpha_b,
            Ealpha,
            Elog_alpha,
            Etau,
            Elog_tau,
            mean_quad,
        ) = runVB(
            B,
            W_m,
            W_var,
            EWtW,
            Mu_m,
            Ealpha,
            Etau,
            sampleN,
            sampleN_sqrt,
            n_studies,
            p_snps,
        )
        check_elbo, pD, kl_qw, kl_qz, kl_qmu, kl_qa, kl_qt = elbo(
            W_m,
            W_var,
            Z_m,
            Z_var,
            Mu_m,
            Mu_var,
            phtau_a,
            phtau_b,
            phalpha_a,
            phalpha_b,
            Ealpha,
            Elog_alpha,
            Etau,
            Elog_tau,
            B,
            mean_quad,
        )

        delbo = check_elbo - oelbo
        oelbo = check_elbo
        if idx % RATE == 0:
            log.info(
                f"itr = {idx} | Elbo = {check_elbo} | deltaElbo = {delbo} | Tau = {Etau}| pD = {pD}|kl_qw={kl_qw}|kl_qz={kl_qz}|kl_qmu={kl_qmu}|kl_qa={kl_qa}|kl_qt={kl_qt}"
            )

        # NOTE(review): dtau is tracked but tau_tol is never applied as a
        # stopping rule; only the ELBO tolerance halts inference.
        dtau = Etau - otau
        otau = Etau

        if jnp.fabs(delbo) < options.elbo_tol:
            break

    r2 = R2(B, W_m, Z_m, Etau, sampleN_sqrt)

    # order factors by |R2|, largest first
    f_order = jnp.argsort(-jnp.abs(r2))
    ordered_r2 = r2[f_order]
    ordered_Z_m = Z_m[:, f_order]
    ordered_W_m = W_m[:, f_order]

    f_info = np.column_stack((jnp.arange(k) + 1, Ealpha.real[f_order], ordered_r2))

    log.info(f"Finished inference after {idx} iterations.")
    log.info(f"Final elbo = {check_elbo} and resid precision = {Etau}")
    log.info(f"Final sorted Ealpha = {jnp.sort(Ealpha)}")
    log.info(f"Final sorted R2 = {ordered_r2}")

    log.info("Writing results.")
    np.savetxt(f"{args.output}.Zm.tsv.gz", ordered_Z_m.real, fmt="%s", delimiter="\t")
    np.savetxt(f"{args.output}.Wm.tsv.gz", ordered_W_m.real, fmt="%s", delimiter="\t")
    np.savetxt(f"{args.output}.factor.tsv.gz", f_info, fmt="%s", delimiter="\t")

    # diagonal of the shared W covariance, in factor order
    ordered_W_var = jnp.diagonal(W_var)[f_order]

    # per-study diagonal of each kxk Z covariance, in factor order
    Z_var_diag = np.zeros((n_studies, k))
    for i in range(n_studies):
        Z_var_diag[i] = jnp.diagonal(Z_var[i])
    ordered_Z_var_diag = Z_var_diag[:, f_order]

    np.savetxt(
        f"{args.output}.Wvar.tsv.gz", ordered_W_var.real, fmt="%s", delimiter="\t"
    )
    np.savetxt(
        f"{args.output}.Zvar.tsv.gz", ordered_Z_var_diag.real, fmt="%s", delimiter="\t"
    )

    log.info("Finished. Goodbye.")

    return 0


# user call this script will treat it like a program
if __name__ == "__main__":
    # grab all arguments; the first arg is always the name of the script
    sys.exit(main(sys.argv[1:]))