├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── __init__.py
├── examples
│   ├── jobscript.sh
│   ├── score_variants.crc.sh
│   ├── score_variants.neuro.sh
│   └── score_variants.sherlock.neuro.sh
├── pyproject.toml
├── scripts
│   └── get_caqtl_data.sh
├── src
│   ├── __init__.py
│   ├── archive
│   │   ├── old.variant_scoring.per_chrom.shuf_pvals.py
│   │   ├── old_chunks.py
│   │   └── variant_scoring.per_chunk.py
│   ├── generators
│   │   ├── __init__.py
│   │   ├── peak_generator.py
│   │   └── variant_generator.py
│   ├── hitcaller_variant.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── argmanager.py
│   │   ├── helpers.py
│   │   ├── io.py
│   │   ├── losses.py
│   │   ├── one_hot.py
│   │   └── shap_utils.py
│   ├── variant_annotation.py
│   ├── variant_scoring.per_chrom.py
│   ├── variant_scoring.py
│   ├── variant_shap.py
│   └── variant_summary_across_folds.py
└── tests
    ├── archive
    │   ├── annotations
    │   │   └── test.annotations.tsv
    │   ├── test.annotations.sh
    │   ├── test.forward_only.sh
    │   ├── test.per_chrom.forward_only.sh
    │   ├── test.per_chrom.sh
    │   └── test.sh
    ├── conftest.py
    ├── data
    │   ├── caqtls.african.lcls.benchmarking.subset.tsv
    │   ├── test.anno_input.tsv
    │   ├── test.bed
    │   ├── test.chrombpnet.incorrect.tsv
    │   ├── test.chrombpnet.incorrect2.tsv
    │   ├── test.chrombpnet.no_chr.tsv
    │   ├── test.chrombpnet.tsv
    │   ├── test.genes.bed
    │   ├── test.hits.bed
    │   ├── test.incorrect.bed
    │   ├── test.original.tsv
    │   ├── test.peaks.bed
    │   └── test.plink.tsv
    ├── test_load_bed_files.py
    ├── test_load_variant_table.py
    ├── test_one_hot.py
    ├── test_variant_annotation.py
    └── test_variant_scoring.py

/.gitignore:
--------------------------------------------------------------------------------
# Pytest
.pytest_cache/
__pycache__/
*.pyc
*.pyo
*.pyd
.coverage
htmlcov/
.tox/
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Test outputs
tests/outputs/
tests/tmp/
tests/data/raw

.vscode
test/output

*/.ipynb_checkpoints
*/*/.ipynb_checkpoints
*/*/*/.ipynb_checkpoints
tests/archive/.ipynb_checkpoints

--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------

## Test data

Test data is derived from the African caQTLs associated with the ChromBPNet preprint
(Pampari et al., bioRxiv 2024). The variants are derived from the dataset
on Synapse at https://www.synapse.org/Synapse:syn64126781.

The download and processing of these variants to prepare the test data are documented
in `scripts/get_caqtl_data.sh`.


## Unit testing

Unit testing is set up with `pytest`.

Some of the tests depend on ChromBPNet models or genome references stored on Oak at
`${OAK}/projects/variant-scorer-test`.

For example, to run the tests on Sherlock, request an interactive node with a GPU:

```bash
sh_dev -g 1 -t 120
```

Activate the conda environment associated with the `variant-scorer` repo, installing
`pytest` there if needed:

```bash
conda activate variant-scorer
pip install pytest
```

Check the output of your `OAK` variable:

```bash
echo $OAK
```

Run the tests:

```bash
pytest -rs -s
```

Optionally, to run without a GPU, use:

```bash
pytest -rs -s -m "not gpu"
```

Or, to skip all the tests that require Oak data, use:

```bash
pytest -rs -s -m "not oak"
```
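The `gpu` and `oak` markers used above are declared in `pyproject.toml`. As a minimal
sketch (the test name and body here are hypothetical, not tests from this repo), a test
that needs both resources would be marked like this, so that `-m "not gpu"` or
`-m "not oak"` deselects it:

```python
import pytest

@pytest.mark.gpu   # deselected by: pytest -m "not gpu"
@pytest.mark.oak   # deselected by: pytest -m "not oak"
def test_scoring_with_oak_model():
    # hypothetical test body: would load a model from ${OAK}/projects/variant-scorer-test
    ...
```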
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 Kundaje Lab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# variant-scorer

The variant scoring repository provides a set of scripts for scoring genetic variants using a ChromBPNet model.

**Important notes:**

- In the input variant list, the `pos` (position) column is expected to be the 1-indexed SNP position, unless the schema is *bed*.
- The reported log fold-change (`logFC`) for predicted variant effects is in log base 2.
- By default, the counts and profile predictions for each allele are averaged between the predictions obtained using
the forward sequence and the reverse complement of that sequence as input. This can be disabled using
the `--forward_only` option to use only the forward-sequence predictions (see [issue 28](https://github.com/kundajelab/variant-scorer/issues/28#issuecomment-2900574336)
for a discussion).


# Variant inputs

## Variant file schemas

Variant lists should be provided as TSVs with one variant per row, and column
names adhering to one of the following schemas:

* chrombpnet : `['chr', 'pos', 'allele1', 'allele2', 'variant_id']`
* bed : `['chr', 'pos', 'end', 'allele1', 'allele2', 'variant_id']`
* plink : `['chr', 'variant_id', 'ignore1', 'pos', 'allele1', 'allele2']`
* original : `['chr', 'pos', 'variant_id', 'allele1', 'allele2']`

**NOTE:** The `pos` (position) column is expected to correspond to the 1-indexed variant position, unless the schema is `bed`.
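As an illustrative sketch (the `load_variants` helper below is hypothetical, not part of
the package API), these schemas map naturally onto `pandas` column names, and a
`bed`-schema start can be shifted to the 1-indexed convention used by the other formats:

```python
import pandas as pd

SCHEMA = {
    "chrombpnet": ["chr", "pos", "allele1", "allele2", "variant_id"],
    "bed":        ["chr", "pos", "end", "allele1", "allele2", "variant_id"],
    "plink":      ["chr", "variant_id", "ignore1", "pos", "allele1", "allele2"],
    "original":   ["chr", "pos", "variant_id", "allele1", "allele2"],
}

def load_variants(path, schema="chrombpnet"):
    df = pd.read_csv(path, sep="\t", header=None, names=SCHEMA[schema])
    if schema == "bed":
        df["pos"] = df["pos"] + 1  # BED starts are 0-indexed; shift to the 1-indexed convention
    return df
```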
## Specifying variants

For single-nucleotide variants, provide the corresponding nucleotide for each allele in
`allele1` and `allele2`, e.g. (for variants provided in the `chrombpnet` schema):

```
chr1 866281 C T 1_866281_C_T
```

For deletions, use `-` for `allele2` to represent the deleted nucleotides, e.g.

```
chr1 866281 C - 1_866281_Cdel
```

For insertions, use `-` for `allele1` to represent the inserted nucleotides, e.g.

```
chr1 866281 - CT 1_866281_CTins
```
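To make the `-` convention concrete, the toy sketch below (illustrative only; the real
logic lives in `src/generators/variant_generator.py` and additionally keeps both allele
sequences the same length) splices each allele into a reference window, where `allele1`
must match the reference bases:

```python
def splice(window, i, ref, alt):
    """Return (allele1_seq, allele2_seq): ref occupies window[i:i+len(ref)];
    alt replaces it, with '-' denoting the empty string."""
    ref = "" if ref == "-" else ref
    alt = "" if alt == "-" else alt
    assert window[i:i+len(ref)] == ref  # allele1 must be the reference allele
    return (window[:i] + ref + window[i+len(ref):],
            window[:i] + alt + window[i+len(ref):])

w = "ACGTCTTGAC"                # toy reference window; the variant base is at index 4
print(splice(w, 4, "C", "T"))   # SNV: ('ACGTCTTGAC', 'ACGTTTTGAC')
print(splice(w, 4, "C", "-"))   # deletion of C: allele2 sequence is one base shorter
print(splice(w, 4, "-", "CT"))  # insertion of CT: allele2 sequence is two bases longer
```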
# Workflow

## 1. Score variants: `variant_scoring.py`

This script takes a list of variants in various input formats and generates scores
for the variants using a ChromBPNet model. The output is a TSV file containing the scores for each variant.
Since all variants are held in memory, we also provide `variant_scoring.per_chrom.py` to score variants on a per-chromosome basis,
writing the scores for each chromosome to file before proceeding to the next chromosome. The per-chromosome
files can then be merged automatically using the `--merge` option.

### Usage:

```bash
python src/variant_scoring.py --list [VARIANTS_FILE] \
    --genome [GENOME_FASTA] \
    --model [MODEL_PATH] \
    --out_prefix [OUT_PREFIX] \
    --chrom_sizes [CHROM_SIZES] \
    [OTHER_ARGS]
```

### Input arguments:

- `-h`, `--help`: Show help message with arguments and their descriptions, and exit
- `-l`, `--list` (**required**): Path to TSV file containing a list of variants to score
- `-g`, `--genome` (**required**): Path to the genome FASTA
- `-pg`, `--peak_genome`: Path to the genome FASTA for peaks
- `-m`, `--model` (**required**): Path to the ChromBPNet model .h5 file to use for variant scoring. For most use cases, this should be the bias-corrected model (chrombpnet_nobias.h5)
- `-o`, `--out_prefix` (**required**): Output prefix for storing SNP effect score predictions from the script, in the form `[DIR]/[PREFIX]`. The directory should already exist.
- `-s`, `--chrom_sizes` (**required**): Path to TSV file with chromosome sizes
- `--no_hdf5`: Do not save basepair-resolution predictions to an hdf5 file. Recommended for large variant lists.
- `-ps`, `--peak_chrom_sizes`: Path to TSV file with chromosome sizes for the peak genome
- `-b`, `--bias`: Bias model to use for variant scoring
- `-li`, `--lite`: Models were trained with chrombpnet-lite
- `-dm`, `--debug_mode`: Display allele input sequences
- `-bs`, `--batch_size`: Batch size to use for the model
- `-sc`, `--schema`: Format of the input variants TSV file. Choices: `bed`, `plink`, `chrombpnet`, `original`
- `-p`, `--peaks`: Path to BED file containing peak regions
- `-n`, `--num_shuf`: Number of shuffled scores per SNP
- `-t`, `--total_shuf`: Total number of shuffled scores across all SNPs. Overrides `--num_shuf`
- `-mp`, `--max_peaks`: Maximum number of peaks to use for peak percentile calculation
- `-c`, `--chrom`: Only score SNPs in the selected chromosome
- `-r`, `--random_seed`: Random seed for reproducibility when sampling
- `-nc`, `--num_chunks`: Number of chunks to divide the SNP file into
- `-fo`, `--forward_only`: Run variant scoring only on the forward sequence (Default: False)
- `-st`, `--shap_type`: ChromBPNet output for which SHAP values should be computed (`counts` or `profile`). Default is `counts`
- `-sh`, `--shuffled_scores`: Path to pre-computed shuffled scores


### Outputs:

The variant scores are stored in `[OUT_PREFIX].variant_scores.tsv`.

Predicted effects are computed as `allele2` vs `allele1`. For each variant, we
compute the following metrics, as described in the [ChromBPNet preprint](https://www.biorxiv.org/content/10.1101/2024.12.25.630221v1.full.pdf+html):

- `logfc`: Log (base 2) fold-change of total predicted coverage for `allele2` vs `allele1`, providing a canonical effect size of the variant on local accessibility. A higher `logFC` indicates higher predicted accessibility for `allele2` compared to `allele1`.
- `abs_logfc`: Absolute value of the log fold-change.
- `active_allele_quantile`: Active Allele Quantile (AAQ) is the percentile of the predicted total coverage of the stronger allele relative to the distribution of predicted total coverage across all ATAC-seq/DNase-seq peaks.
- `jsd`: Jensen-Shannon distance between the bias-corrected base-resolution probability profiles of the two alleles, which captures effects on profile shape, such as changes in TF footprints.

We provide several additional metrics that are computed as products of the above metrics:

- `abs_logfc_x_jsd`: described in the preprint as the Integrative Effect Size (IES), the product of the absolute logFC and the JSD
- `logfc_x_active_allele_quantile`
- `abs_logfc_x_active_allele_quantile`
- `jsd_x_active_allele_quantile`
- `logfc_x_jsd_x_active_allele_quantile`: described in the preprint as the Integrative Prioritization Score (IPS), the product of logFC, JSD, and AAQ
- `abs_logfc_x_jsd_x_active_allele_quantile`

*__NOTE__*: For profile predictions, the saved arrays consist of model logits, not probabilities. This allows profile predictions to be averaged across folds more easily, by averaging the logits over folds and then taking the softmax (see [`variant-scorer/pull/23`](https://github.com/kundajelab/variant-scorer/pull/23)).
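As an illustrative sketch of how these metrics relate (toy inputs and `numpy`/`scipy`
only; this mirrors the definitions above rather than the exact implementation):

```python
import numpy as np
from scipy.spatial.distance import jensenshannon
from scipy.special import softmax

# toy inputs, standing in for model predictions
allele1_counts, allele2_counts = 120.0, 300.0         # predicted total coverage per allele
allele1_logits = np.random.randn(1000)                # profile logits per allele
allele2_logits = allele1_logits + 0.1 * np.random.randn(1000)
peak_counts = np.random.lognormal(5, 1, size=10000)   # predicted coverage across all peaks

logfc = np.log2(allele2_counts / allele1_counts)                       # logfc
jsd = jensenshannon(softmax(allele2_logits), softmax(allele1_logits))  # jsd
aaq = np.mean(peak_counts < max(allele1_counts, allele2_counts))       # active_allele_quantile
ies = abs(logfc) * jsd       # abs_logfc_x_jsd (IES)
ips = logfc * jsd * aaq      # logfc_x_jsd_x_active_allele_quantile (IPS)
```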
## 2. Summarize variant scores across model folds: `variant_summary_across_folds.py`

This script takes variant scores generated by the `variant_scoring.py` script for several model folds,
and generates a TSV file with the mean score for each score metric across folds.

### Usage:

```bash
python src/variant_summary_across_folds.py \
    --score_dir [VARIANT_SCORE_DIR] \
    --score_list [SCORE_LIST] \
    --out_prefix [OUT_PREFIX] \
    --schema [SCHEMA]
```

### Input arguments:

- `-h`, `--help`: Show help message with arguments and their descriptions, and exit
- `-sd`, `--score_dir` (**required**): Path to directory containing the variant scores that will be used to generate the summary
- `-sl`, `--score_list` (**required**): Space-separated list of variant score file names that will be used to generate the summary. Files should exist in `--score_dir`.
- `-o`, `--out_prefix` (**required**): Output prefix for storing the summary file with average scores across folds, in the form `[DIR]/[PREFIX]`. The directory should already exist.
- `-sc`, `--schema`: Format of the input variants list. Choices: `bed`, `plink`, `plink2`, `chrombpnet`, `original`. Default is `chrombpnet`.


### Outputs:

The summary file is stored at `[OUT_PREFIX].mean.variant_scores.tsv`.
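Conceptually, the summary is a per-variant mean of each numeric score column across the
fold files. An illustrative `pandas` sketch, with hypothetical file names (not the
script's exact implementation):

```python
import pandas as pd

# hypothetical per-fold score files produced by variant_scoring.py
folds = [pd.read_csv(f"scores.fold{i}.variant_scores.tsv", sep="\t") for i in range(5)]
stacked = pd.concat(folds)

id_cols = ["chr", "pos", "allele1", "allele2", "variant_id"]  # chrombpnet schema
mean_scores = stacked.groupby(id_cols, sort=False).mean(numeric_only=True).reset_index()
mean_scores.to_csv("scores.mean.variant_scores.tsv", sep="\t", index=False)
```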
## 3. Annotate variants: `variant_annotation.py`

This script takes a list of variants and annotates each with its closest genes
and/or its overlaps with peaks or motif hits.

**NOTE:** This script assumes that the genes, peaks, and hits are in the same reference genome as the variants; it does not perform any liftOver operations.

### Usage:

```bash
python src/variant_annotation.py \
    --list [VARIANT_SCORES or VARIANT_LIST] \
    --out_prefix [OUT_PREFIX] \
    --peaks [PEAKS] \
    --genes [GENES] \
    --hits [HITS] \
    --schema [SCHEMA]
```

### Input arguments:

- `-h`, `--help`: Show help message with arguments and their descriptions, and exit
- `-l`, `--list` (**required**): Path to TSV file containing scored variants as output by `variant_scoring.py` (or the summary across folds), or a BED file of variants with the `--schema bed` option.
- `-o`, `--out_prefix` (**required**): Output prefix for storing the annotated file, in the form `[DIR]/[PREFIX]`. The directory should already exist.
- `-sc`, `--schema`: Format of the input variants list. Use `bed` if providing a BED file of variants.
- `-ge`, `--genes`: Path to BED file containing gene regions. If provided, the script will annotate each variant with the three closest genes and the distance to each.
- `-p`, `--peaks`: Path to BED file containing peak regions. If provided, the script will annotate each variant according to whether it overlaps with any peak.
- `--hits`: Path to BED file containing motif hits, with columns `chr`, `start`, `end`, `motif`, `score`, `strand`, `class`. If provided, the script will annotate variants with whether they overlap any motif hits (and which ones they overlap).

At least one of `--genes`, `--peaks`, or `--hits` must be provided for annotation.

## 4. Compute variant SHAP scores: `variant_shap.py`

This script computes the contribution scores for each variant, for allele1
and allele2, with respect to the specified ChromBPNet model output (`counts` or `profile`).

### Usage:

```bash
python src/variant_shap.py \
    --list [VARIANTS_FILE] \
    --genome [GENOME] \
    --chrom_sizes [CHROM_SIZES] \
    --model [MODEL_PATH] \
    --out_prefix [OUT_PREFIX] \
    --schema [SCHEMA] \
    --shap_type [SHAP_TYPE] \
    [OTHER_ARGS]
```

### Input arguments:

- `-h`, `--help`: Show help message with arguments and their descriptions, and exit
- `-l`, `--list` (**required**): A TSV file containing a list of variants to score
- `-g`, `--genome` (**required**): Path to the genome FASTA
- `-m`, `--model` (**required**): Path to the ChromBPNet model .h5 file to use for variant scoring. For most use cases, this should be the bias-corrected model (chrombpnet_nobias.h5)
- `-o`, `--out_prefix` (**required**): Output prefix for storing SNP effect score predictions from the script, in the form `[DIR]/[PREFIX]`. The directory should already exist.
- `-s`, `--chrom_sizes` (**required**): Path to TSV file with chromosome sizes
- `-li`, `--lite`: Models were trained with chrombpnet-lite
- `-dm`, `--debug_mode`: Display allele input sequences
- `-bs`, `--batch_size`: Batch size to use for the model. Default is 10000
- `-sc`, `--schema`: Format of the input variants list. Choices: `bed`, `plink`, `chrombpnet`, `original`. Default is `chrombpnet`
- `-c`, `--chrom`: Only score SNPs in the selected chromosome
- `-st`, `--shap_type`: ChromBPNet output for which SHAP values should be computed. Can specify multiple values. Default is `counts`


### Outputs:

The variant SHAP scores are stored in `[OUT_PREFIX].variant_shap.[SHAP_TYPE].h5`.

The h5 file contains the following datasets:

- `alleles`: shape `(2 * number of variants, )`, binary array indicating whether the allele is allele1 (0) or allele2 (1)
- `raw/seq`: shape `(2 * number of variants, 4, 2114)`, contains the one-hot encoding of the sequences around the variants which were scored
- `shap/seq`: shape `(2 * number of variants, 4, 2114)`, contains the hypothetical contribution scores
- `projected_shap/seq`: shape `(2 * number of variants, 4, 2114)`, contains the values obtained by multiplying the hypothetical SHAP values with the raw (one-hot encoded) sequences
- `variant_ids`: shape `(2 * number of variants, )`, contains the variant identifier of each variant as provided in the input list
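As an illustrative sketch (assuming `h5py` and a hypothetical output path), the datasets
can be read back and split by allele:

```python
import h5py

with h5py.File("out.variant_shap.counts.h5") as f:  # hypothetical output path
    alleles = f["alleles"][:]           # 0 = allele1, 1 = allele2
    proj = f["projected_shap/seq"][:]   # (2 * n_variants, 4, 2114)
    variant_ids = f["variant_ids"][:]

allele1_scores = proj[alleles == 0]     # contribution scores for allele1 sequences
allele2_scores = proj[alleles == 1]     # contribution scores for allele2 sequences
```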
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kundajelab/variant-scorer/0e1e34199e63112aa618748bb79a206fc491300a/__init__.py

--------------------------------------------------------------------------------
/examples/jobscript.sh:
--------------------------------------------------------------------------------
#!/bin/bash

module load cudnn

# https://stackoverflow.com/questions/34534513/calling-conda-source-activate-from-bash-script
eval "$(conda shell.bash hook)"
conda activate chrombpnet

echo "Live"
python "$@"

--------------------------------------------------------------------------------
/examples/score_variants.crc.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# export so the setting reaches the python child process
export CUDA_VISIBLE_DEVICES=1

list=/oak/stanford/groups/akundaje/soumyak/refs/plink_1kg_hg38/all.1000G.EUR.QC.bim
genome=/oak/stanford/groups/akundaje/refs/hg38/GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta
sizes=/oak/stanford/groups/akundaje/refs/hg38/hg38.chrom.sizes
model=/mnt/lab_data3/soumyak/CRC_finemap/output/chrombpnet/20220124_celltype_models_5foldCV_Myo2bias/Cancer_Associated_Fibroblasts/fold0/Cancer_Associated_Fibroblasts.h5
bias=/mnt/lab_data3/soumyak/CRC_finemap/output/chrombpnet/bias_models/Myofibroblasts_2/Myofibroblasts_2.bias.2114.1000.h5
out_prefix=/mnt/lab_data3/soumyak/variant-scorer/examples/crc.caf_fold0.all.1000G.EUR

time python ../src/variant_scoring.py -l $list \
    -g $genome \
    -s $sizes \
    -m $model \
    -o $out_prefix \
    -sc plink \
    -dm \
    -li

--------------------------------------------------------------------------------
/examples/score_variants.neuro.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# export so the setting reaches the python child process
export CUDA_VISIBLE_DEVICES=1

list=/oak/stanford/groups/akundaje/soumyak/refs/plink_1kg_hg38/all.1000G.EUR.QC.bim
genome=/oak/stanford/groups/akundaje/refs/hg38/GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta
model=/srv/scratch/soumyak/neuro-variants/outs/1_31_2022_adpd_model_training/full_models/cluster1_fold0/chrombpnet_wo_bias.h5
bias=/srv/scratch/soumyak/neuro-variants/outs/1_31_2022_adpd_model_training/full_models/cluster1_fold0/chrombpnet.h5
out_prefix=/mnt/lab_data3/soumyak/variant-scorer/examples/neuro.cluster1_fold0.all.1000G.EUR

time python ../src/variant_scoring.py -l $list \
    -g $genome \
    -m $model \
    -o $out_prefix \
    -sc plink \
    -b $bias \
    -dm

--------------------------------------------------------------------------------
/examples/score_variants.sherlock.neuro.sh:
--------------------------------------------------------------------------------
#!/bin/bash

set -e
set -o pipefail
set -u

list=/oak/stanford/groups/akundaje/soumyak/refs/plink_1kg_hg38/all.1000G.EUR.QC.bim
genome=/oak/stanford/groups/akundaje/refs/hg38/GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta
model=/oak/stanford/groups/akundaje/projects/neuro-variants/outs/1_31_2022_adpd_model_training/full_models/cluster1_fold0/chrombpnet_wo_bias.h5
bias=/oak/stanford/groups/akundaje/projects/neuro-variants/outs/1_31_2022_adpd_model_training/full_models/cluster1_fold0/chrombpnet.h5
out_prefix=/oak/stanford/groups/akundaje/projects/variant-scorer/examples/neuro.cluster1_fold0.all.1000G.EUR

JOBSCRIPT=/home/groups/akundaje/soumyak/variant-scorer/examples/jobscript.sh
job=neuro.cluster1_fold0

sbatch -J $job \
    -t 60 -c 2 --mem=20G -p akundaje,gpu --gpus 1 \
    -o /oak/stanford/groups/akundaje/projects/variant-scorer/examples/$job.log.txt \
    -e /oak/stanford/groups/akundaje/projects/variant-scorer/examples/$job.err.txt \
    $JOBSCRIPT /home/groups/akundaje/soumyak/variant-scorer/src/variant_scoring.py \
    -l $list \
    -g $genome \
    -m $model \
    -o $out_prefix \
    -sc plink \
    -b $bias \
    -dm

--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[tool.pytest.ini_options]
markers = [
    "gpu: tests that require GPU hardware",
    "oak: tests required to be run on Oak"
]

--------------------------------------------------------------------------------
/scripts/get_caqtl_data.sh:
--------------------------------------------------------------------------------
#!/usr/bin/bash

# Selin Jessa
# July 2025

# Purpose: Document the process of setting up caQTL data for package testing. These data
# are provided with the ChromBPNet manuscript (Pampari et al., bioRxiv, 2024).
# The variants provided are 1-based in hg38.

# 1. Download the caQTL data (https://www.synapse.org/Synapse:syn64126781)
#    Requires the synapse client & authentication first.
if [[ ! -f ../tests/data/raw/caqtls.african.lcls.benchmarking.all.tsv ]]; then
    echo "Downloading caQTL data from Synapse..."
    synapse get syn64126781
    gunzip caqtls.african.lcls.benchmarking.all.tsv.gz
    mv caqtls.african.lcls.benchmarking.all.tsv ../tests/data/raw/
else
    echo "caQTL data already downloaded."
fi

TEST_DIR="../tests/data"

# 2. Set rsids to keep
RSID="rs7417106|rs3121577|rs2465131|rs2488995|rs2488996|rs7527973|rs2063455|rs2685245|rs4402801|rs4854274"

# 3. Get the appropriate columns and filter to the selected variants
#  1 var.chr        : Chromosome of the variant (GRCh38)
#  2 var.pos_hg38   : Position of the variant (GRCh38, 1-based)
#  3 var.allele1    : Allele 1 for the variant
#  4 var.allele2    : Allele 2 for the variant
#  5 var.isused     : True if variant is used in final ChromBPNet benchmarking
# 27 pred.chrombpnet.encsr637xsc.varscore.logfc : ChromBPNet logFC predictions in encid encsr637xsc
# 28 pred.chrombpnet.encsr637xsc.varscore.jsd   : ChromBPNet JSD predictions in encid encsr637xsc
# 33 var.snp_id     : variant identifier 1
# 36 var.dbsnp_rsid : dbSNP rsid identifier
cut -f 1,2,3,4,5,27,28,33,36 ${TEST_DIR}/raw/caqtls.african.lcls.benchmarking.all.tsv \
    | awk -F'\t' -v rsid_regex="$RSID" 'BEGIN{OFS="\t"} {if (NR==1 || $9 ~ rsid_regex) print $1,$2,$3,$4,$5,$6,$7,$8,$9}' \
    > ${TEST_DIR}/caqtls.african.lcls.benchmarking.subset.tsv
# 4. Convert to various input formats for testing

# BED (0-based: ['chr', 'pos', 'end', 'allele1', 'allele2', 'variant_id'])
awk 'BEGIN{OFS="\t"} {print $1,$2-1,$2,$3,$4,$9}' ${TEST_DIR}/caqtls.african.lcls.benchmarking.subset.tsv \
    | tail -n+2 \
    > ${TEST_DIR}/test.bed

# ChromBPNet input format (1-based, ['chr', 'pos', 'allele1', 'allele2', 'variant_id'])
awk 'BEGIN{OFS="\t"} {print $1,$2,$3,$4,$9}' ${TEST_DIR}/caqtls.african.lcls.benchmarking.subset.tsv \
    | tail -n+2 \
    > ${TEST_DIR}/test.chrombpnet.tsv

# ChromBPNet format missing the 'chr' prefix
sed 's/^chr//g' ${TEST_DIR}/test.chrombpnet.tsv > ${TEST_DIR}/test.chrombpnet.no_chr.tsv

# plink format (1-based, ['chr', 'variant_id', 'ignore1', 'pos', 'allele1', 'allele2'])
awk 'BEGIN{OFS="\t"} {print $1,$9,"0",$2,$3,$4}' ${TEST_DIR}/caqtls.african.lcls.benchmarking.subset.tsv \
    | tail -n+2 \
    > ${TEST_DIR}/test.plink.tsv

# original format (1-based, ['chr', 'pos', 'variant_id', 'allele1', 'allele2'])
awk 'BEGIN{OFS="\t"} {print $1,$2,$9,$3,$4}' ${TEST_DIR}/caqtls.african.lcls.benchmarking.subset.tsv \
    | tail -n+2 \
    > ${TEST_DIR}/test.original.tsv
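For reference, an illustrative `pandas` equivalent of the first two `awk` conversions
above (not part of the script; the remaining formats are analogous column reorderings):

```python
import pandas as pd

subset = pd.read_csv("tests/data/caqtls.african.lcls.benchmarking.subset.tsv", sep="\t")
chrom, pos, a1, a2 = (subset.iloc[:, i] for i in (0, 1, 2, 3))
rsid = subset.iloc[:, 8]  # dbSNP rsid, used as the variant_id

# BED: shift the 1-based position to a 0-based start
bed = pd.DataFrame({"chr": chrom, "start": pos - 1, "end": pos,
                    "allele1": a1, "allele2": a2, "variant_id": rsid})
bed.to_csv("tests/data/test.bed", sep="\t", header=False, index=False)

# ChromBPNet: keep the 1-based position
chrombpnet = pd.DataFrame({"chr": chrom, "pos": pos, "allele1": a1,
                           "allele2": a2, "variant_id": rsid})
chrombpnet.to_csv("tests/data/test.chrombpnet.tsv", sep="\t", header=False, index=False)
```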
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kundajelab/variant-scorer/0e1e34199e63112aa618748bb79a206fc491300a/src/__init__.py

--------------------------------------------------------------------------------
/src/archive/old.variant_scoring.per_chrom.shuf_pvals.py:
--------------------------------------------------------------------------------
from snp_generator import SNPGenerator
from peak_generator import PeakGenerator
from utils import argmanager, losses
import scipy.stats
from scipy.spatial.distance import jensenshannon
from tensorflow.keras.utils import get_custom_objects
from tensorflow.keras.models import load_model
import tensorflow as tf
import pandas as pd
import os
import argparse
import numpy as np
import h5py
import psutil
from tqdm import tqdm
import statsmodels.stats.multitest


SCHEMA = {'original': ["chr", "pos", "rsid", "allele1", "allele2"],
          'plink': ["chr", "rsid", "ignore1", "pos", "allele1", "allele2"],
          'narrowpeak': ['chr', 'start', 'end', 3, 4, 5, 6, 7, 'rank', 'summit'],
          'bed': ['chr', 'start', 'pos', 'allele1', 'allele2', 'rsid', 'snp_id'],
          'chrombpnet': ["chr", "pos", "allele1", "allele2", "rsid"]}

def main():
    args = argmanager.fetch_scoring_args()
    print(args)

    out_dir = os.path.sep.join(args.out_prefix.split(os.path.sep)[:-1])
    if not os.path.exists(out_dir):
        raise OSError("Output directory does not exist")

    # load the model
    model = load_model_wrapper(args.model)

    # load the variants
    variants_table = pd.read_csv(args.list, header=None, sep='\t', names=SCHEMA[args.schema])
    variants_table.drop(columns=[x for x in variants_table.columns if x.startswith('ignore')], inplace=True)
    variants_table['chr'] = variants_table['chr'].astype(str)
    has_chr_prefix = any('chr' in x for x in variants_table['chr'].tolist())
    if not has_chr_prefix:
        variants_table['chr'] = 'chr' + variants_table['chr']

    chrom_sizes = pd.read_csv(args.chrom_sizes, header=None, sep='\t', names=['chrom', 'size'])
    chrom_sizes_dict = chrom_sizes.set_index('chrom')['size'].to_dict()

    if args.chrom:
        variants_table = variants_table.loc[variants_table['chr'] == args.chrom]

    # infer input length
    if args.lite:
        input_len = model.input_shape[0][1]
    else:
        input_len = model.input_shape[1]
    print("input length inferred from the model: ", input_len)

    print(variants_table.shape)
    variants_table = variants_table.loc[variants_table.apply(lambda x: get_valid_variants(x.chr, x.pos, x.allele1, x.allele2, input_len, chrom_sizes_dict), axis=1)]
    print(variants_table.shape)

    variants_table.reset_index(drop=True, inplace=True)

    if args.max_shuf:
        if len(variants_table) > args.max_shuf:
            shuf_variants_table = variants_table.sample(args.max_shuf)
            args.num_shuf = 1
        else:
            shuf_variants_table = variants_table.copy()
    else:
        shuf_variants_table = variants_table.copy()

    print(shuf_variants_table.shape)

    shuf_rsids, shuf_allele1_count_preds, shuf_allele2_count_preds, \
    shuf_allele1_profile_preds, shuf_allele2_profile_preds = fetch_variant_predictions(model,
                                                                                       shuf_variants_table,
                                                                                       input_len,
                                                                                       args.genome,
                                                                                       args.batch_size,
                                                                                       debug_mode=args.debug_mode,
                                                                                       lite=args.lite,
                                                                                       bias=None,
                                                                                       shuf=True,
                                                                                       num_shuf=args.num_shuf)

    if args.peaks:
        if args.peak_chrom_sizes == None:
            args.peak_chrom_sizes = args.chrom_sizes
        if args.peak_genome == None:
            args.peak_genome = args.genome

        peak_chrom_sizes = pd.read_csv(args.peak_chrom_sizes, header=None, sep='\t', names=['chrom', 'size'])
        peak_chrom_sizes_dict = peak_chrom_sizes.set_index('chrom')['size'].to_dict()

        peaks = pd.read_csv(args.peaks, header=None, sep='\t', names=SCHEMA['narrowpeak'])
        peaks.sort_values(by=['chr', 'start', 'end', 'summit', 'rank'], ascending=[True, True, True, True, False], inplace=True)
        peaks.drop_duplicates(subset=['chr', 'start', 'end', 'summit'], inplace=True)

        print(peaks.shape)
        peaks = peaks.loc[peaks.apply(lambda x: get_valid_peaks(x.chr, x.start, x.summit, input_len, peak_chrom_sizes_dict), axis=1)]

        if args.debug_mode:
            peaks = peaks.sample(10000)

        if args.max_peaks:
            if len(peaks) > args.max_peaks:
                peaks = peaks.sample(args.max_peaks)

        peaks.reset_index(drop=True, inplace=True)
        print(peaks.shape)

        count_preds, profile_preds = fetch_peak_predictions(model,
                                                            peaks,
                                                            input_len,
                                                            args.peak_genome,
                                                            args.batch_size,
                                                            debug_mode=args.debug_mode,
                                                            lite=args.lite,
                                                            bias=None)

        shuf_log_fold_change, shuf_profile_jsd, \
        shuf_allele1_percentile, shuf_allele2_percentile, \
        shuf_percentile_change = get_variant_scores_with_peaks(shuf_allele1_count_preds, shuf_allele2_count_preds,
                                                               shuf_allele1_profile_preds, shuf_allele2_profile_preds,
                                                               count_preds)

        shuf_max_percentile = np.maximum(shuf_allele1_percentile, shuf_allele2_percentile)
        shuf_logfc_jsd = np.squeeze(np.abs(shuf_log_fold_change)) * shuf_profile_jsd
        shuf_logfc_jsd_max_percentile = shuf_logfc_jsd * shuf_max_percentile

    else:
        shuf_log_fold_change, shuf_profile_jsd = get_variant_scores(shuf_allele1_count_preds, shuf_allele2_count_preds,
                                                                    shuf_allele1_profile_preds, shuf_allele2_profile_preds)
"variant_predictions.h5"]))] 136 | 137 | for chrom in todo_chroms: 138 | print() 139 | print(chrom) 140 | print() 141 | 142 | chrom_variants_table = variants_table.loc[variants_table['chr'] == chrom].sort_values(by='pos').copy() 143 | chrom_variants_table.reset_index(drop=True, inplace=True) 144 | print(chrom_variants_table.shape) 145 | print() 146 | 147 | if args.debug_mode: 148 | chrom_variants_table = chrom_variants_table.sample(100000) 149 | print(chrom_variants_table.head()) 150 | 151 | chrom_variants_table.reset_index(drop=True, inplace=True) 152 | 153 | # fetch model prediction for variants 154 | rsids, allele1_count_preds, allele2_count_preds, \ 155 | allele1_profile_preds, allele2_profile_preds = fetch_variant_predictions(model, 156 | chrom_variants_table, 157 | input_len, 158 | args.genome, 159 | args.batch_size, 160 | debug_mode=args.debug_mode, 161 | lite=args.lite, 162 | bias=None, 163 | shuf=False, 164 | num_shuf=args.num_shuf) 165 | 166 | if args.peaks: 167 | log_fold_change, profile_jsd, \ 168 | allele1_percentile, allele2_percentile, \ 169 | percentile_change = get_variant_scores_with_peaks(allele1_count_preds, allele2_count_preds, 170 | allele1_profile_preds, allele2_profile_preds, 171 | count_preds) 172 | 173 | else: 174 | log_fold_change, profile_jsd = get_variant_scores(allele1_count_preds, allele2_count_preds, allele1_profile_preds, allele2_profile_preds) 175 | 176 | # unpack rsids to write outputs and write score to output 177 | assert np.array_equal(chrom_variants_table["rsid"].tolist(), rsids) 178 | chrom_variants_table["log_fold_change"] = log_fold_change 179 | chrom_variants_table["profile_jsd"] = profile_jsd 180 | chrom_variants_table["logfc_jsd"] = abs(chrom_variants_table["log_fold_change"]) * chrom_variants_table["profile_jsd"] 181 | chrom_variants_table["allele1_pred_counts"] = allele1_count_preds 182 | chrom_variants_table["allele2_pred_counts"] = allele2_count_preds 183 | chrom_variants_table["log_fold_change_pval"] = chrom_variants_table["log_fold_change"].apply(lambda x: 184 | 2 * min(scipy.stats.percentileofscore(shuf_log_fold_change, x) / 100, 185 | 1 - (scipy.stats.percentileofscore(shuf_log_fold_change, x) / 100))) 186 | chrom_variants_table["profile_jsd_pval"] = chrom_variants_table["profile_jsd"].apply(lambda x: 187 | 1 - (scipy.stats.percentileofscore(shuf_profile_jsd, x) / 100)) 188 | chrom_variants_table["logfc_jsd_pval"] = chrom_variants_table["logfc_jsd"].apply(lambda x: 189 | 1 - (scipy.stats.percentileofscore(shuf_logfc_jsd, x) / 100)) 190 | 191 | if args.peaks: 192 | chrom_variants_table["allele1_percentile"] = allele1_percentile 193 | chrom_variants_table["allele2_percentile"] = allele2_percentile 194 | chrom_variants_table["max_percentile"] = chrom_variants_table[["allele1_percentile", "allele2_percentile"]].max(axis=1) 195 | chrom_variants_table["logfc_jsd_max_percentile"] = chrom_variants_table["logfc_jsd"] * chrom_variants_table["max_percentile"] 196 | chrom_variants_table["logfc_jsd_max_percentile_pval"] = chrom_variants_table["logfc_jsd_max_percentile"].apply(lambda x: 197 | 1 - (scipy.stats.percentileofscore(shuf_logfc_jsd_max_percentile, x) / 100)) 198 | chrom_variants_table["percentile_change"] = percentile_change 199 | chrom_variants_table["percentile_change_pval"] = chrom_variants_table["percentile_change"].apply(lambda x: 200 | 2 * min(scipy.stats.percentileofscore(shuf_percentile_change, x) / 100, 201 | 1 - (scipy.stats.percentileofscore(shuf_percentile_change, x) / 100))) 202 | 203 | 
        chrom_variants_table.to_csv('.'.join([args.out_prefix, chrom, "variant_scores.tsv"]), sep="\t", index=False)

        # store predictions at variants
        with h5py.File('.'.join([args.out_prefix, chrom, "variant_predictions.h5"]), 'w') as f:
            wo_bias = f.create_group('wo_bias')
            wo_bias.create_dataset('allele1_pred_counts', data=allele1_count_preds)
            wo_bias.create_dataset('allele2_pred_counts', data=allele2_count_preds)
            wo_bias.create_dataset('allele1_pred_profile', data=allele1_profile_preds)
            wo_bias.create_dataset('allele2_pred_profile', data=allele2_profile_preds)
            wo_bias.create_dataset('shuf_log_fold_change', data=shuf_log_fold_change)
            wo_bias.create_dataset('shuf_profile_jsd', data=shuf_profile_jsd)
            wo_bias.create_dataset('shuf_logfc_jsd', data=shuf_logfc_jsd)
            if args.peaks:
                wo_bias.create_dataset('shuf_percentile_change', data=shuf_percentile_change)
                wo_bias.create_dataset('shuf_max_percentile', data=shuf_max_percentile)
                wo_bias.create_dataset('shuf_logfc_jsd_max_percentile', data=shuf_logfc_jsd_max_percentile)

        print(str(chrom) + " DONE")


def poisson_pval(allele1_counts, allele2_counts):
    if allele2_counts > allele1_counts:
        pval = 1 - scipy.stats.poisson.cdf(allele2_counts, allele1_counts)
    else:
        pval = scipy.stats.poisson.cdf(allele2_counts, allele1_counts)
    pval = pval * 2
    return pval

def poisson_pval_best(allele1_counts, allele2_counts):
    if allele2_counts > allele1_counts:
        pval_1 = 1 - scipy.stats.poisson.cdf(allele2_counts, allele1_counts)
        pval_2 = scipy.stats.poisson.cdf(allele1_counts, allele2_counts)
    else:
        pval_1 = scipy.stats.poisson.cdf(allele2_counts, allele1_counts)
        pval_2 = 1 - scipy.stats.poisson.cdf(allele1_counts, allele2_counts)
    pval = min([pval_1, pval_2]) * 2
    return pval

def poisson_pval_worst(allele1_counts, allele2_counts):
    if allele2_counts > allele1_counts:
        pval_1 = 1 - scipy.stats.poisson.cdf(allele2_counts, allele1_counts)
        pval_2 = scipy.stats.poisson.cdf(allele1_counts, allele2_counts)
    else:
        pval_1 = scipy.stats.poisson.cdf(allele2_counts, allele1_counts)
        pval_2 = 1 - scipy.stats.poisson.cdf(allele1_counts, allele2_counts)
    pval = max([pval_1, pval_2]) * 2
    return pval

def get_valid_peaks(chrom, pos, summit, input_len, chrom_sizes_dict):
    flank = input_len // 2
    lower_check = ((pos + summit) - flank > 0)
    upper_check = ((pos + summit) + flank <= chrom_sizes_dict[chrom])
    in_bounds = lower_check and upper_check
    return in_bounds

def get_valid_variants(chrom, pos, allele1, allele2, input_len, chrom_sizes_dict):
    flank = input_len // 2
    lower_check = (pos - flank > 0)
    upper_check = (pos + flank <= chrom_sizes_dict[chrom])
    in_bounds = lower_check and upper_check
    no_allele1_indel = (len(allele1) == 1)
    no_allele2_indel = (len(allele2) == 1)
    no_indels = no_allele1_indel and no_allele2_indel
    valid_variants = in_bounds and no_indels
    return valid_variants

def softmax(x, temp=1):
    norm_x = x - np.mean(x, axis=1, keepdims=True)
    return np.exp(temp * norm_x) / np.sum(np.exp(temp * norm_x), axis=1, keepdims=True)

def load_model_wrapper(model_file):
    # read .h5 model
    custom_objects = {"multinomial_nll": losses.multinomial_nll, "tf": tf}
    get_custom_objects().update(custom_objects)
    model = load_model(model_file, compile=False)
    print("model loaded successfully")
    return model

def fetch_peak_predictions(model, peaks, input_len, genome_fasta, batch_size, debug_mode=False, lite=False, bias=None):
    count_preds = []
    profile_preds = []

    # peak sequence generator
    peak_gen = PeakGenerator(peaks=peaks,
                             input_len=input_len,
                             genome_fasta=genome_fasta,
                             batch_size=batch_size,
                             debug_mode=debug_mode)

    for i in tqdm(range(len(peak_gen))):

        seqs = peak_gen[i]

        if lite:
            if bias != None:
                bias_batch_preds = bias.predict(seqs, verbose=False)

                batch_preds = model.predict([seqs,
                                             bias_batch_preds[0],
                                             bias_batch_preds[1]],
                                            verbose=False)
            else:
                batch_preds = model.predict([seqs,
                                             np.zeros((len(seqs), model.output_shape[0][1])),
                                             np.zeros((len(seqs), ))],
                                            verbose=False)
        else:
            if bias != None:
                batch_preds = bias.predict(seqs, verbose=False)
            else:
                batch_preds = model.predict(seqs, verbose=False)

        count_preds.extend(np.exp(np.squeeze(batch_preds[1])) - 1)
        profile_preds.extend(np.squeeze(softmax(batch_preds[0])))

    count_preds = np.array(count_preds)
    profile_preds = np.array(profile_preds)

    return count_preds, profile_preds

def fetch_variant_predictions(model, variants_table, input_len, genome_fasta, batch_size, debug_mode=False, lite=False, bias=None, shuf=False, num_shuf=10):
    rsids = []
    allele1_count_preds = []
    allele2_count_preds = []
    allele1_profile_preds = []
    allele2_profile_preds = []

    # snp sequence generator
    snp_gen = SNPGenerator(variants_table=variants_table,
                           input_len=input_len,
                           genome_fasta=genome_fasta,
                           batch_size=batch_size,
                           debug_mode=debug_mode,
                           shuf=shuf,
                           num_shuf=num_shuf)

    for i in tqdm(range(len(snp_gen))):

        batch_rsids, allele1_seqs, allele2_seqs = snp_gen[i]

        if lite:
            if bias != None:
                allele1_bias_batch_preds = bias.predict(allele1_seqs, verbose=False)
                allele2_bias_batch_preds = bias.predict(allele2_seqs, verbose=False)

                allele1_batch_preds = model.predict([allele1_seqs,
                                                     allele1_bias_batch_preds[0],
                                                     allele1_bias_batch_preds[1]],
                                                    verbose=False)
                allele2_batch_preds = model.predict([allele2_seqs,
                                                     allele2_bias_batch_preds[0],
                                                     allele2_bias_batch_preds[1]],
                                                    verbose=False)
            else:
                allele1_batch_preds = model.predict([allele1_seqs,
                                                     np.zeros((len(allele1_seqs), model.output_shape[0][1])),
                                                     np.zeros((len(allele1_seqs), ))],
                                                    verbose=False)
                allele2_batch_preds = model.predict([allele2_seqs,
                                                     np.zeros((len(allele2_seqs), model.output_shape[0][1])),
                                                     np.zeros((len(allele2_seqs), ))],
                                                    verbose=False)
        else:
            if bias != None:
                allele1_batch_preds = bias.predict(allele1_seqs, verbose=False)
                allele2_batch_preds = bias.predict(allele2_seqs, verbose=False)
            else:
                allele1_batch_preds = model.predict(allele1_seqs, verbose=False)
                allele2_batch_preds = model.predict(allele2_seqs, verbose=False)

        allele1_batch_preds[1] = np.array([allele1_batch_preds[1][i] for i in range(len(allele1_batch_preds[1]))])
        allele2_batch_preds[1] = np.array([allele2_batch_preds[1][i] for i in range(len(allele2_batch_preds[1]))])

        allele1_count_preds.extend(np.exp(allele1_batch_preds[1]))
        allele2_count_preds.extend(np.exp(allele2_batch_preds[1]))

        allele1_profile_preds.extend(np.squeeze(softmax(allele1_batch_preds[0])))
        allele2_profile_preds.extend(np.squeeze(softmax(allele2_batch_preds[0])))

        rsids.extend(batch_rsids)

    rsids = np.array(rsids)
    allele1_count_preds = np.array(allele1_count_preds)
    allele2_count_preds = np.array(allele2_count_preds)
    allele1_profile_preds = np.array(allele1_profile_preds)
    allele2_profile_preds = np.array(allele2_profile_preds)

    return rsids, allele1_count_preds, allele2_count_preds, \
        allele1_profile_preds, allele2_profile_preds

def get_variant_scores_with_peaks(allele1_count_preds, allele2_count_preds,
                                  allele1_profile_preds, allele2_profile_preds, count_preds):
    log_fold_change = np.log2(allele2_count_preds / allele1_count_preds)
    profile_jsd_diff = np.array([jensenshannon(x, y) for x, y in zip(allele2_profile_preds, allele1_profile_preds)])
    allele1_percentile = np.array([np.mean(count_preds < x) for x in allele1_count_preds])
    allele2_percentile = np.array([np.mean(count_preds < x) for x in allele2_count_preds])
    percentile_change = allele2_percentile - allele1_percentile

    return log_fold_change, profile_jsd_diff, allele1_percentile, allele2_percentile, percentile_change

def get_variant_scores(allele1_count_preds, allele2_count_preds,
                       allele1_profile_preds, allele2_profile_preds):
    log_fold_change = np.log2(allele2_count_preds / allele1_count_preds)
    profile_jsd_diff = np.array([jensenshannon(x, y) for x, y in zip(allele2_profile_preds, allele1_profile_preds)])

    return log_fold_change, profile_jsd_diff

if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/src/archive/variant_scoring.per_chunk.py:
--------------------------------------------------------------------------------
from tensorflow.keras.utils import get_custom_objects
from tensorflow.keras.models import load_model
import tensorflow as tf
import scipy.stats
from scipy.spatial.distance import jensenshannon
import pandas as pd
import os
import argparse
import numpy as np
import h5py
import math
from generators.snp_generator import SNPGenerator
from generators.peak_generator import PeakGenerator
from utils import argmanager, losses
from utils.helpers import *


def main():
    args = argmanager.fetch_scoring_args()
    print(args)

    np.random.seed(args.random_seed)

    out_dir = os.path.sep.join(args.out_prefix.split(os.path.sep)[:-1])
    if not os.path.exists(out_dir):
        raise OSError("Output directory does not exist")

    # load the model
    model = load_model_wrapper(args.model)

    # load the variants
    variants_table = pd.read_csv(args.list, header=None, sep='\t', names=get_snp_schema(args.schema))
    variants_table.drop(columns=[str(x) for x in variants_table.columns if str(x).startswith('ignore')], inplace=True)
    variants_table['chr'] = variants_table['chr'].astype(str)
    has_chr_prefix = any('chr' in x.lower() for x in variants_table['chr'].tolist())
    if not has_chr_prefix:
        variants_table['chr'] = 'chr' + variants_table['chr']

    chrom_sizes = pd.read_csv(args.chrom_sizes, header=None, sep='\t', names=['chrom', 'size'])
    chrom_sizes_dict = chrom_sizes.set_index('chrom')['size'].to_dict()

    print("Original variants table shape:", variants_table.shape)

    if args.chrom:
        variants_table = variants_table.loc[variants_table['chr'] == args.chrom]
        print("Chromosome variants table shape:", variants_table.shape)

    # infer input length
    if args.lite:
        input_len = model.input_shape[0][1]
    else:
        input_len = model.input_shape[1]

    print("Input length inferred from the model:", input_len)

    variants_table = variants_table.loc[variants_table.apply(lambda x: get_valid_variants(x.chr, x.pos, x.allele1, x.allele2, input_len, chrom_sizes_dict), axis=1)]
    variants_table.reset_index(drop=True, inplace=True)

    print("Final variants table shape:", variants_table.shape)

    if args.total_shuf:
        if len(variants_table) > args.total_shuf:
            shuf_variants_table = variants_table.sample(args.total_shuf,
                                                        random_state=args.random_seed,
                                                        ignore_index=True,
                                                        replace=False)
        else:
            shuf_variants_table = variants_table.sample(args.total_shuf,
                                                        random_state=args.random_seed,
                                                        ignore_index=True,
                                                        replace=True)
    else:
        # args.total_shuf is unset in this branch, so sample the locally computed total
        total_shuf = len(variants_table) * args.num_shuf
        shuf_variants_table = variants_table.sample(total_shuf,
                                                    random_state=args.random_seed,
                                                    ignore_index=True,
                                                    replace=True)

    shuf_variants_table['random_seed'] = np.random.permutation(len(shuf_variants_table))

    print("Shuffled variants table shape:", shuf_variants_table.shape)

    if len(shuf_variants_table) > 0:
        if args.debug_mode:
            shuf_variants_table = shuf_variants_table.sample(10000, random_state=args.random_seed, ignore_index=True)
            print()
            print(shuf_variants_table.head())
            print("Debug shuffled variants table shape:", shuf_variants_table.shape)
            print()

        shuf_rsids, shuf_allele1_pred_counts, shuf_allele2_pred_counts, \
        shuf_allele1_pred_profiles, shuf_allele2_pred_profiles = fetch_variant_predictions(model,
                                                                                           shuf_variants_table,
                                                                                           input_len,
                                                                                           args.genome,
                                                                                           args.batch_size,
                                                                                           debug_mode=args.debug_mode,
                                                                                           lite=args.lite,
                                                                                           shuf=True)

    if args.peaks:
        if args.peak_chrom_sizes == None:
            args.peak_chrom_sizes = args.chrom_sizes
        if args.peak_genome == None:
            args.peak_genome = args.genome

        peak_chrom_sizes = pd.read_csv(args.peak_chrom_sizes, header=None, sep='\t', names=['chrom', 'size'])
        peak_chrom_sizes_dict = peak_chrom_sizes.set_index('chrom')['size'].to_dict()

        peaks = pd.read_csv(args.peaks, header=None, sep='\t', names=get_peak_schema('narrowpeak'))

        print("Original peak table shape:", peaks.shape)

        peaks.sort_values(by=['chr', 'start', 'end', 'summit', 'rank'], ascending=[True, True, True, True, False], inplace=True)
        peaks.drop_duplicates(subset=['chr', 'start', 'end', 'summit'], inplace=True)
        peaks = peaks.loc[peaks.apply(lambda x: get_valid_peaks(x.chr, x.start, x.summit, input_len, peak_chrom_sizes_dict), axis=1)]
        peaks.reset_index(drop=True, inplace=True)

        print("De-duplicated peak table shape:", peaks.shape)

        if args.debug_mode:
            peaks = peaks.sample(10000, random_state=args.random_seed, ignore_index=True)
            print()
            print(peaks.head())
            print("Debug peak table shape:", peaks.shape)
            print()

        if args.max_peaks:
            if len(peaks) > args.max_peaks:
                peaks = peaks.sample(args.max_peaks, random_state=args.random_seed, ignore_index=True)
                print("Subsampled peak table shape:", peaks.shape)

        pred_counts, pred_profiles = fetch_peak_predictions(model,
                                                            peaks,
                                                            input_len,
                                                            args.peak_genome,
                                                            args.batch_size,
                                                            debug_mode=args.debug_mode,
                                                            lite=args.lite)

        if len(shuf_variants_table) > 0:
            shuf_logfc, shuf_jsd, \
            shuf_allele1_percentile, shuf_allele2_percentile = get_variant_scores_with_peaks(shuf_allele1_pred_counts,
                                                                                             shuf_allele2_pred_counts,
                                                                                             shuf_allele1_pred_profiles,
                                                                                             shuf_allele2_pred_profiles,
                                                                                             pred_counts)

            shuf_max_percentile = np.maximum(shuf_allele1_percentile, shuf_allele2_percentile)
            shuf_percentile_change = shuf_allele2_percentile - shuf_allele1_percentile
            shuf_abs_logfc = np.squeeze(np.abs(shuf_logfc))
            shuf_abs_logfc_jsd = shuf_abs_logfc * shuf_jsd
            shuf_abs_logfc_jsd_max_percentile = shuf_abs_logfc_jsd * shuf_max_percentile

    else:
        if len(shuf_variants_table) > 0:
            shuf_logfc, shuf_jsd = get_variant_scores(shuf_allele1_pred_counts,
                                                      shuf_allele2_pred_counts,
                                                      shuf_allele1_pred_profiles,
                                                      shuf_allele2_pred_profiles)
            shuf_abs_logfc = np.squeeze(np.abs(shuf_logfc))
            shuf_abs_logfc_jsd = shuf_abs_logfc * shuf_jsd

    todo_chunks = [x for x in range(args.num_chunks) if not os.path.exists('.'.join([args.out_prefix, str(x), "variant_predictions.h5"]))]
    chunk_frac = (1 / args.num_chunks)

    for chunk in todo_chunks:
        print()
        print(chunk)
        print()

        chunk_variants_table = variants_table.iloc[math.ceil((len(variants_table) / args.num_chunks) * chunk):
                                                   math.ceil((len(variants_table) / args.num_chunks) * (chunk + 1))].copy()
        chunk_variants_table.reset_index(drop=True, inplace=True)

        print(str(chunk) + " variants table shape:", chunk_variants_table.shape)
        print()

        if args.debug_mode:
            chunk_variants_table = chunk_variants_table.sample(10000, random_state=args.random_seed, ignore_index=True)
            print()
            print(chunk_variants_table.head())
            print("Debug " + str(chunk) + " variants table shape:", chunk_variants_table.shape)
            print()

        # fetch model prediction for variants
        rsids, allele1_pred_counts, allele2_pred_counts, \
        allele1_pred_profiles, allele2_pred_profiles = fetch_variant_predictions(model,
                                                                                 chunk_variants_table,
                                                                                 input_len,
                                                                                 args.genome,
                                                                                 args.batch_size,
                                                                                 debug_mode=args.debug_mode,
                                                                                 lite=args.lite,
                                                                                 shuf=False)

        if args.peaks:
            logfc, jsd, \
            allele1_percentile, allele2_percentile = get_variant_scores_with_peaks(allele1_pred_counts,
                                                                                   allele2_pred_counts,
                                                                                   allele1_pred_profiles,
                                                                                   allele2_pred_profiles,
                                                                                   pred_counts)

        else:
            logfc, jsd = get_variant_scores(allele1_pred_counts,
                                            allele2_pred_counts,
                                            allele1_pred_profiles,
                                            allele2_pred_profiles)

        # unpack rsids to write outputs and write score to output
        assert np.array_equal(chunk_variants_table["rsid"].tolist(), rsids)
        chunk_variants_table["allele1_pred_counts"] = allele1_pred_counts
        chunk_variants_table["allele2_pred_counts"] = allele2_pred_counts
        chunk_variants_table["logfc"] = logfc
        chunk_variants_table["abs_logfc"] = abs(chunk_variants_table["logfc"])
        chunk_variants_table["jsd"] = jsd
        chunk_variants_table["abs_logfc_x_jsd"] = chunk_variants_table["abs_logfc"] * chunk_variants_table["jsd"]

        if len(shuf_variants_table) > 0:
            chunk_variants_table["logfc.pval"] = chunk_variants_table["logfc"].apply(lambda x:
                2 * min(scipy.stats.percentileofscore(shuf_logfc, x) / 100,
                        1 - (scipy.stats.percentileofscore(shuf_logfc, x) / 100)))
            chunk_variants_table["jsd.pval"] = chunk_variants_table["jsd"].apply(lambda x:
                1 - (scipy.stats.percentileofscore(shuf_jsd, x) / 100))
            chunk_variants_table["abs_logfc_x_jsd.pval"] = chunk_variants_table["abs_logfc_x_jsd"].apply(lambda x:
                1 - (scipy.stats.percentileofscore(shuf_abs_logfc_jsd, x) / 100))

        if args.peaks:
            chunk_variants_table["allele1_percentile"] = allele1_percentile
            chunk_variants_table["allele2_percentile"] = allele2_percentile
            chunk_variants_table["max_percentile"] = chunk_variants_table[["allele1_percentile", "allele2_percentile"]].max(axis=1)
            chunk_variants_table["percentile_change"] = chunk_variants_table["allele2_percentile"] - chunk_variants_table["allele1_percentile"]
            chunk_variants_table["abs_logfc_x_jsd_x_max_percentile"] = chunk_variants_table["abs_logfc_x_jsd"] * chunk_variants_table["max_percentile"]

            if len(shuf_variants_table) > 0:
                chunk_variants_table["max_percentile.pval"] = chunk_variants_table["max_percentile"].apply(lambda x:
                    1 - (scipy.stats.percentileofscore(shuf_max_percentile, x) / 100))
                chunk_variants_table["percentile_change.pval"] = chunk_variants_table["percentile_change"].apply(lambda x:
                    2 * min(scipy.stats.percentileofscore(shuf_percentile_change, x) / 100,
                            1 - (scipy.stats.percentileofscore(shuf_percentile_change, x) / 100)))
                chunk_variants_table["abs_logfc_x_jsd_x_max_percentile.pval"] = chunk_variants_table["abs_logfc_x_jsd_x_max_percentile"].apply(lambda x:
                    1 - (scipy.stats.percentileofscore(shuf_abs_logfc_jsd_max_percentile, x) / 100))

        if args.schema == "bed":
            chunk_variants_table['pos'] = chunk_variants_table['pos'] - 1

        print()
        print(chunk_variants_table.head())
        print("Output " + str(chunk) + " score table shape:", chunk_variants_table.shape)
        print()

        chunk_variants_table.to_csv('.'.join([args.out_prefix, str(chunk), "variant_scores.tsv"]), sep="\t", index=False)

        # store predictions at variants
        if not args.no_hdf5:
            with h5py.File('.'.join([args.out_prefix, str(chunk), "variant_predictions.h5"]), 'w') as f:
                observed = f.create_group('observed')
                observed.create_dataset('allele1_pred_counts', data=allele1_pred_counts, compression='gzip', compression_opts=9)
                observed.create_dataset('allele2_pred_counts', data=allele2_pred_counts, compression='gzip', compression_opts=9)
                observed.create_dataset('allele1_pred_profiles', data=allele1_pred_profiles, compression='gzip', compression_opts=9)
                observed.create_dataset('allele2_pred_profiles', data=allele2_pred_profiles, compression='gzip', compression_opts=9)
                if len(shuf_variants_table) > 0:
                    shuffled = f.create_group('shuffled')
                    shuffled.create_dataset('shuf_allele1_pred_counts', data=shuf_allele1_pred_counts, compression='gzip', compression_opts=9)
                    shuffled.create_dataset('shuf_allele2_pred_counts', data=shuf_allele2_pred_counts, compression='gzip', compression_opts=9)
                    shuffled.create_dataset('shuf_logfc', data=shuf_logfc, compression='gzip', compression_opts=9)
                    shuffled.create_dataset('shuf_abs_logfc', data=shuf_abs_logfc, compression='gzip', compression_opts=9)
                    shuffled.create_dataset('shuf_jsd', data=shuf_jsd, compression='gzip', compression_opts=9)
                    shuffled.create_dataset('shuf_abs_logfc_x_jsd', data=shuf_abs_logfc_jsd, compression='gzip', compression_opts=9)
                    if args.peaks:
                        shuffled.create_dataset('shuf_max_percentile', data=shuf_max_percentile, compression='gzip', compression_opts=9)
                        shuffled.create_dataset('shuf_percentile_change', data=shuf_percentile_change, compression='gzip', compression_opts=9)
                        shuffled.create_dataset('shuf_abs_logfc_x_jsd_x_max_percentile', data=shuf_abs_logfc_jsd_max_percentile, compression='gzip', compression_opts=9)

        print("DONE:", str(chunk))
        print()


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/src/generators/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kundajelab/variant-scorer/0e1e34199e63112aa618748bb79a206fc491300a/src/generators/__init__.py

--------------------------------------------------------------------------------
/src/generators/peak_generator.py:
--------------------------------------------------------------------------------
from tensorflow.keras.utils import Sequence
import pandas as pd
import numpy as np
import math
import pyfaidx
from utils import one_hot


class PeakGenerator(Sequence):
    def __init__(self,
                 peaks,
                 input_len,
                 genome_fasta,
                 batch_size=512,
                 debug_mode=False):

        self.peaks = peaks
        self.num_peaks = self.peaks.shape[0]
        self.input_len = input_len
        self.genome = pyfaidx.Fasta(genome_fasta)
        self.debug_mode = debug_mode
        self.flank_size = self.input_len // 2
        self.batch_size = batch_size

    def __get_seq__(self, chrom, start, summit):
        chrom = str(chrom)
        start = int(start)
        summit = int(start) + int(summit)
        flank_start = int(summit - self.flank_size)
        flank_end = int(summit + (self.flank_size - 1))
        flank = str(self.genome.get_seq(chrom, flank_start, flank_end))
        return flank

    def __getitem__(self, idx):
        cur_entries = self.peaks.iloc[idx*self.batch_size:min([self.num_peaks, (idx+1)*self.batch_size])]
        peak_ids = cur_entries['chr'] + ':' + cur_entries['start'].astype(str) + '-' + cur_entries['end'].astype(str)

        seqs = [self.__get_seq__(x, y, z) for x, y, z in
                zip(cur_entries.chr, cur_entries.start, cur_entries.summit)]

        return peak_ids, one_hot.dna_to_one_hot(seqs)

    def __len__(self):
        return math.ceil(self.num_peaks / self.batch_size)
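A minimal usage sketch for `PeakGenerator` (the peak table, FASTA path, and input length
below are illustrative assumptions, not values from this repo):

```python
import pandas as pd
from generators.peak_generator import PeakGenerator

# toy narrowPeak-style table with the columns __getitem__ uses
peaks = pd.DataFrame({"chr": ["chr1"], "start": [1_000_000],
                      "end": [1_001_000], "summit": [500]})
gen = PeakGenerator(peaks=peaks, input_len=2114,
                    genome_fasta="hg38.fa",  # assumed local reference FASTA
                    batch_size=512)
peak_ids, onehot = gen[0]  # peak IDs and one-hot encoded sequences for the first batch
```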
self.debug_mode = debug_mode 24 | self.flank_size = self.input_len // 2 25 | self.shuf = shuf 26 | self.batch_size = batch_size 27 | 28 | def __get_allele_seq__(self, chrom, pos, allele1, allele2, seed=-1): 29 | chrom = str(chrom) 30 | pos = int(pos) 31 | allele1 = str(allele1) 32 | allele2 = str(allele2) 33 | 34 | if allele1 == "-": 35 | allele1 = "" 36 | if allele2 == "-": 37 | allele2 = "" 38 | ### 1-indexed position 39 | pos = pos - 1 40 | 41 | if len(allele1) == len(allele2): 42 | flank = str(self.genome[chrom][pos-self.flank_size:pos+self.flank_size]) 43 | if self.shuf: 44 | assert seed != -1 45 | flank = dinuc_shuffle(flank, rng=np.random.RandomState(seed)) 46 | allele1_seq = flank[:self.flank_size] + allele1 + flank[self.flank_size+len(allele1):] 47 | allele2_seq = flank[:self.flank_size] + allele2 + flank[self.flank_size+len(allele2):] 48 | 49 | ### handle INDELS (allele1 must be the reference allele) 50 | # Here, we adjust the flanks to account for the INDEL and ensure that 51 | # the allele1 and allele2 sequences are the same length 52 | else: 53 | ### hg19 has lower case 54 | assert len(allele1) != len(allele2) 55 | assert self.genome[chrom][pos:pos+len(allele1)].seq.upper() == allele1 56 | mismatch_length = len(allele1) - len(allele2) 57 | if mismatch_length > 0: 58 | flank = str(self.genome[chrom][pos-self.flank_size:pos+self.flank_size+mismatch_length]) 59 | else: 60 | flank = str(self.genome[chrom][pos-self.flank_size:pos+self.flank_size]) 61 | 62 | if self.shuf: 63 | assert seed != -1 64 | flank = dinuc_shuffle(flank, rng=np.random.RandomState(seed)) 65 | 66 | left_flank = flank[:self.flank_size] 67 | 68 | allele1_right_flank = flank[self.flank_size+len(allele1):self.flank_size*2] 69 | allele2_right_flank = flank[self.flank_size+len(allele1):self.flank_size*2+mismatch_length] 70 | 71 | allele1_seq = left_flank + allele1 + allele1_right_flank 72 | allele2_seq = left_flank + allele2 + allele2_right_flank 73 | 74 | assert len(allele1_seq) == self.flank_size * 2 75 | assert len(allele2_seq) == self.flank_size * 2 76 | return allele1_seq, allele2_seq 77 | 78 | def __getitem__(self, idx): 79 | cur_entries = self.variants_table.iloc[idx*self.batch_size:min([self.num_variants,(idx+1)*self.batch_size])] 80 | variant_ids = cur_entries['variant_id'].tolist() 81 | 82 | if self.shuf: 83 | allele1_seqs, allele2_seqs = zip(*[self.__get_allele_seq__(v, w, x, y, z) for v,w,x,y,z in 84 | zip(cur_entries.chr, cur_entries.pos, 85 | cur_entries.allele1, cur_entries.allele2, cur_entries.random_seed)]) 86 | else: 87 | allele1_seqs, allele2_seqs = zip(*[self.__get_allele_seq__(w, x, y, z) for w,x,y,z in 88 | zip(cur_entries.chr, cur_entries.pos, cur_entries.allele1, cur_entries.allele2)]) 89 | 90 | if self.debug_mode: 91 | return variant_ids, list(allele1_seqs), list(allele2_seqs) 92 | else: 93 | return variant_ids, one_hot.dna_to_one_hot(list(allele1_seqs)), one_hot.dna_to_one_hot(list(allele2_seqs)) 94 | 95 | def __len__(self): 96 | return math.ceil(self.num_variants/self.batch_size) 97 | -------------------------------------------------------------------------------- /src/hitcaller_variant.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import h5py 3 | import subprocess 4 | import argparse 5 | import os 6 | import pandas as pd 7 | 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser(description="Runs hit calling on variant-based interpretation data") 11 | parser.add_argument("--shap_data", type=str, required=True,
help="h5 or npz file containing variant sequences and shap_scores") 12 | parser.add_argument("--input_type", type=str, choices=["h5", "npz"], default="h5", help="Whether the input data is in h5 or npz format") 13 | parser.add_argument("--modisco_h5", type=str, help="Modisco h5 file from relevant experiment") 14 | parser.add_argument("--variant_file", type=str, help="variant-scorer style file containing list of variants. Required if you want genomic locations as part of the final report") 15 | parser.add_argument("--hits_per_loc", type=int, help="Maximum number of hits to return per sequence per locus") 16 | parser.add_argument("--output_dir", type=str, help="Output directory") 17 | parser.add_argument("--alpha", type=float, default=0.6, help="Alpha value for hit calling") 18 | parser.add_argument("--include_motifs", type=str, help="Include motifs") 19 | args = parser.parse_args() 20 | return args 21 | 22 | 23 | def h5_to_npz(args): 24 | ''' 25 | If the input is given as a h5, then this function runs the relevant finemo command to convert it to a npz file 26 | Regions of width 100 are extracted, since this should be sufficient for hits containing the central variant 27 | ''' 28 | extract_command = ["finemo", "extract-regions-chrombpnet-h5", "-c", args.shap_data, "-w", "100", "-o", os.path.join(args.output_dir, "shap_input.npz")] 29 | subprocess.run(extract_command) 30 | 31 | def run_hit_calling(args, npz_file): 32 | ''' 33 | Runs hit calling given the npz file with input interpretation data 34 | ''' 35 | if args.variant_file is not None: 36 | if args.include_motifs is not None: 37 | subprocess.run(["finemo", "call-hits", "-r", npz_file, "-m", args.modisco_h5, "-o", args.output_dir, "-b", "256", "-a", str(args.alpha), "-p", os.path.join(args.output_dir, "variant_locs.narrowPeak"), "-I", args.include_motifs, "-N", args.include_motifs]) 38 | else: 39 | subprocess.run(["finemo", "call-hits", "-r", npz_file, "-m", args.modisco_h5, "-o", args.output_dir, "-b", "256", "-a", str(args.alpha), "-p", os.path.join(args.output_dir, "variant_locs.narrowPeak")]) 40 | else: 41 | if args.include_motifs is not None: 42 | subprocess.run(["finemo", "call-hits", "-r", npz_file, "-m", args.modisco_h5, "-o", args.output_dir, "-b", "256", "-a", str(args.alpha), "-I", args.include_motifs, "-N", args.include_motifs]) 43 | else: 44 | subprocess.run(["finemo", "call-hits", "-r", npz_file, "-m", args.modisco_h5, "-o", args.output_dir, "-b", "256", "-a", str(args.alpha)]) 45 | 46 | def parse_hit_calls(args): 47 | ''' 48 | Given an output file from the hit caller, identifies hits containing the central variant and returns the top n hits per sequence 49 | ''' 50 | hits_file = os.path.join(args.output_dir, "hits.tsv") 51 | hits_df = pd.read_csv(hits_file, sep="\t") 52 | # print(hits_df.head()) 53 | 54 | #Define location of variants to identify correct hits 55 | if args.variant_file is not None: 56 | variant_table = pd.read_csv(args.variant_file, sep="\t", header=None) 57 | hits_df["variant_loc"] = variant_table.loc[(hits_df["peak_id"] % len(variant_table)).astype(int), 1].values 58 | print(hits_df.head()) 59 | else: 60 | hits_df["variant_loc"] = [50] * len(hits_df) 61 | 62 | variant_hits = hits_df.loc[(hits_df["start"] <= hits_df["variant_loc"]) & (hits_df["end"] >= hits_df["variant_loc"])].copy() 63 | variant_hits["inv_coeff"] = -1 * variant_hits["hit_coefficient"] 64 | print() 65 | print(variant_hits.head()) 66 | variant_hits = variant_hits.sort_values(["peak_id", "inv_coeff"]).groupby("peak_id").head(args.hits_per_loc) 
67 | if args.variant_file is not None: 68 | variant_hits['allele'] = variant_hits['peak_id'].apply(lambda x: "allele2" if x >= len(variant_table) else "allele1") # 0-based peak_ids: the second copy of the variant table (allele2) starts at index len(variant_table) 69 | else: 70 | variant_hits['allele'] = "N/A" 71 | variant_out_final = variant_hits[["peak_id", "chr", "start", "end", "motif_name", "allele", 72 | "variant_loc", "hit_coefficient", "hit_correlation", "hit_importance"]] 73 | return variant_out_final 74 | 75 | 76 | def variant_file_to_narrowpeak(args): 77 | ''' 78 | Converts a variant info file (i.e., the input to most variant-scorer commands) into a narrowpeak file which can be used with the hit caller 79 | ''' 80 | variant_table = pd.read_csv(args.variant_file, sep="\t", header=None) 81 | narrowpeak_raw_data = [list(variant_table[0].values), list(variant_table[1].values - 1), list(variant_table[1].values + 1), 82 | ["."] * len(variant_table), ["."] * len(variant_table), ["."] * len(variant_table), ["."] * len(variant_table), 83 | ["."] * len(variant_table), ["."] * len(variant_table), [1] * len(variant_table)] 84 | narrowpeak_df = pd.DataFrame(narrowpeak_raw_data).T 85 | narrowpeak_df = pd.concat([narrowpeak_df, narrowpeak_df]) 86 | 87 | narrowpeak_df.to_csv(os.path.join(args.output_dir, "variant_locs.narrowPeak"), sep="\t", header=False, index=False) 88 | return narrowpeak_df 89 | 90 | 91 | def main(): 92 | 93 | #Produce npz file if it does not already exist 94 | args = parse_args() 95 | if args.input_type == "npz": 96 | npz_file = args.shap_data 97 | elif args.input_type == "h5": 98 | h5_to_npz(args) 99 | npz_file = os.path.join(args.output_dir, "shap_input.npz") 100 | 101 | #Produce narrowpeak file if desired 102 | if args.variant_file is not None: 103 | npeak = variant_file_to_narrowpeak(args) 104 | 105 | #Run the hit caller and save the results 106 | run_hit_calling(args, npz_file) 107 | output_df = parse_hit_calls(args) 108 | output_df.to_csv(os.path.join(args.output_dir, "variant_hit_calls.tsv"), sep="\t", header=True, index=False) 109 | 110 | 111 | if __name__ == "__main__": 112 | main() 113 | 114 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kundajelab/variant-scorer/0e1e34199e63112aa618748bb79a206fc491300a/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/argmanager.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def update_scoring_args(parser): 5 | parser.add_argument("-l", "--list", type=str, required=True, help="Path to TSV file containing a list of variants to score") 6 | parser.add_argument("-g", "--genome", type=str, required=True, help="Path to the genome FASTA") 7 | parser.add_argument("-pg", "--peak_genome", type=str, help="Path to the genome FASTA for peaks") 8 | parser.add_argument("-m", "--model", type=str, required=True, help="Path to the ChromBPNet model .h5 file to use for variant scoring. For most use cases, this should be the bias-corrected model (chrombpnet_nobias.h5)") 9 | parser.add_argument("-o", "--out_prefix", type=str, required=True, help="Output prefix for storing SNP effect score predictions from the script, in the form of <out_dir>/<prefix>.
Directory should already exist.") 10 | parser.add_argument("-s", "--chrom_sizes", type=str, required=True, help="Path to TSV file with chromosome sizes") 11 | parser.add_argument("-ps", "--peak_chrom_sizes", type=str, help="Path to TSV file with chromosome sizes for peak genome") 12 | parser.add_argument("-b", "--bias", type=str, help="Bias model to use for variant scoring") 13 | parser.add_argument("-li", "--lite", action='store_true', help="Models were trained with chrombpnet-lite") 14 | parser.add_argument("-dm", "--debug_mode", action='store_true', help="Display allele input sequences") 15 | parser.add_argument("-bs", "--batch_size", type=int, default=512, help="Batch size to use for the model") 16 | parser.add_argument("-sc", "--schema", type=str, choices=['bed', 'plink', 'chrombpnet', 'original'], default='chrombpnet', help="Format for the input variants TSV file") 17 | parser.add_argument("-p", "--peaks", type=str, help="Path to BED file containing peak regions") 18 | parser.add_argument("-n", "--num_shuf", type=int, default=10, help="Number of shuffled scores per SNP") 19 | parser.add_argument("-t", "--total_shuf", type=int, help="Total number of shuffled scores across all SNPs. Overrides --num_shuf") 20 | parser.add_argument("-mp", "--max_peaks", type=int, help="Maximum number of peaks to use for peak percentile calculation") 21 | parser.add_argument("-c", "--chrom", type=str, help="Only score SNPs in selected chromosome") 22 | parser.add_argument("-r", "--random_seed", type=int, default=1234, help="Random seed for reproducibility when sampling") 23 | parser.add_argument("--no_hdf5", action='store_true', help="Do not save detailed predictions in hdf5 file") 24 | parser.add_argument("-nc", "--num_chunks", type=int, default=10, help="Number of chunks to divide SNP file into") 25 | parser.add_argument("-fo", "--forward_only", action='store_true', help="Run variant scoring only on forward sequence. Default: False") 26 | parser.add_argument("-st", "--shap_type", nargs='+', default=["counts"], help="ChromBPNet output for which SHAP values should be computed ('counts' or 'profile'). Default is 'counts'") 27 | parser.add_argument("-sh", "--shuffled_scores", type=str, help="Path to pre-computed shuffled scores") 28 | parser.add_argument("--merge", action='store_true', help="For per-chromosome scoring, merge all per-chromosome predictions into a single file and delete the per-chromosome files. Default is False.") 29 | 30 | def fetch_scoring_args(): 31 | parser = argparse.ArgumentParser() 32 | update_scoring_args(parser) 33 | args = parser.parse_args() 34 | print(args) 35 | return args 36 | 37 | def update_shap_args(parser): 38 | parser.add_argument("-l", "--list", type=str, required=True, help="A TSV file containing a list of variants to score") 39 | parser.add_argument("-g", "--genome", type=str, required=True, help="Path to genome FASTA") 40 | parser.add_argument("-m", "--model", type=str, required=True, help="Path to the ChromBPNet model .h5 file to use for variant scoring. For most use cases, this should be the bias-corrected model (chrombpnet_nobias.h5)") 41 | parser.add_argument("-o", "--out_prefix", type=str, required=True, help="Output prefix for storing SNP effect score predictions from the script, in the form of <out_dir>/<prefix>.
Directory should already exist.") 42 | parser.add_argument("-s", "--chrom_sizes", type=str, required=True, help="Path to TSV file with chromosome sizes") 43 | parser.add_argument("-li", "--lite", action='store_true', help="Models were trained with chrombpnet-lite") 44 | parser.add_argument("-dm", "--debug_mode", action='store_true', help="Display allele input sequences") 45 | parser.add_argument("-bs", "--batch_size", type=int, default=10000, help="Batch size to use for the model") 46 | parser.add_argument("-sc", "--schema", type=str, choices=['bed', 'plink', 'chrombpnet', 'original'], default='chrombpnet', help="Format for the input variants list") 47 | parser.add_argument("-c", "--chrom", type=str, help="Only score SNPs in selected chromosome") 48 | parser.add_argument("-st", "--shap_type", nargs='+', default=["counts"]) 49 | 50 | def fetch_shap_args(): 51 | parser = argparse.ArgumentParser() 52 | update_shap_args(parser) 53 | args = parser.parse_args() 54 | print(args) 55 | return args 56 | 57 | def update_variant_summary_args(parser): 58 | parser.add_argument("-sd", "--score_dir", type=str, required=True, help="Path to directory containing variant scores that will be used to generate summary") 59 | parser.add_argument("-sl", "--score_list", nargs='+', required=True, help="Space-separated list of variant score file names that will be used to generate summary") 60 | parser.add_argument("-o", "--out_prefix", type=str, required=True, help="Output prefix for storing the summary file with average scores across folds, in the form of <out_dir>/<prefix>. Directory should already exist.") 61 | parser.add_argument("-sc", "--schema", type=str, required=True, choices=['bed', 'plink', 'plink2', 'chrombpnet', 'original'], default='chrombpnet', help="Format for the input variants list") 62 | 63 | def fetch_variant_summary_args(): 64 | parser = argparse.ArgumentParser() 65 | update_variant_summary_args(parser) 66 | args = parser.parse_args() 67 | print(args) 68 | return args 69 | 70 | def update_variant_annotation_args(parser): 71 | parser.add_argument( 72 | "-l", "--list", type=str, required=True, 73 | help=( 74 | "Path to TSV file containing the variant scores (or summarized scores) to annotate.\n" 75 | "Alternatively, provide a BED file of variants with --schema bed.\n" 76 | "The file should contain variant information compatible with the selected schema." 77 | ) 78 | ) 79 | parser.add_argument("-o", "--out_prefix", type=str, required=True, help="Output prefix for storing the annotated file, in the form of <out_dir>/<prefix>.
Directory should already exist.") 80 | parser.add_argument("-p", "--peaks", type=str, help="Path to BED file containing peak regions") 81 | parser.add_argument("--hits", type=str, help="Path to BED file containing motif hits regions") 82 | parser.add_argument("-ge", "--genes", type=str, help="Path to BED file containing gene regions") 83 | parser.add_argument("-sc", "--schema", type=str, required=False, choices=['bed', 'plink', 'plink2', 'chrombpnet', 'original'], default='chrombpnet', help="Format for the input variants list") 84 | 85 | def fetch_variant_annotation_args(): 86 | parser = argparse.ArgumentParser() 87 | update_variant_annotation_args(parser) 88 | args = parser.parse_args() 89 | print(args) 90 | 91 | # Assert that at least one of genes, peaks, or hits is provided 92 | if not args.genes and not args.peaks and not args.hits: 93 | print("Error: At least one of --genes, --peaks, or --hits must be provided for annotation.") 94 | parser.print_help() 95 | exit(1) 96 | 97 | return args 98 | -------------------------------------------------------------------------------- /src/utils/helpers.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras.utils import get_custom_objects 2 | from tensorflow.keras.models import load_model 3 | import tensorflow as tf 4 | from scipy.spatial.distance import jensenshannon 5 | import pandas as pd 6 | import numpy as np 7 | from tqdm import tqdm 8 | import sys 9 | sys.path.append('..') 10 | from generators.variant_generator import VariantGenerator 11 | from generators.peak_generator import PeakGenerator 12 | from utils import losses 13 | from utils.io import get_variant_schema, get_peak_schema 14 | 15 | 16 | def get_valid_peaks(chrom, pos, summit, input_len, chrom_sizes_dict): 17 | valid_chrom = chrom in chrom_sizes_dict 18 | if valid_chrom: 19 | flank = input_len // 2 20 | lower_check = ((pos + summit) - flank > 0) 21 | upper_check = ((pos + summit) + flank <= chrom_sizes_dict[chrom]) 22 | in_bounds = lower_check and upper_check 23 | valid_peak = valid_chrom and in_bounds 24 | return valid_peak 25 | else: 26 | return False 27 | 28 | def get_valid_variants(chrom, pos, allele1, allele2, input_len, chrom_sizes_dict): 29 | valid_chrom = chrom in chrom_sizes_dict 30 | if valid_chrom: 31 | flank = input_len // 2 32 | lower_check = (pos - flank > 0) 33 | upper_check = (pos + flank <= chrom_sizes_dict[chrom]) 34 | in_bounds = lower_check and upper_check 35 | # no_allele1_indel = (len(allele1) == 1) 36 | # no_allele2_indel = (len(allele2) == 1) 37 | # no_indel = no_allele1_indel and no_allele2_indel 38 | # valid_variant = valid_chrom and in_bounds and no_indel 39 | valid_variant = valid_chrom and in_bounds 40 | return valid_variant 41 | else: 42 | return False 43 | 44 | def softmax(x, temp=1): 45 | norm_x = x - np.mean(x, axis=1, keepdims=True) 46 | return np.exp(temp*norm_x)/np.sum(np.exp(temp*norm_x), axis=1, keepdims=True) 47 | 48 | def load_model_wrapper(model_file): 49 | # read .h5 model 50 | custom_objects = {"multinomial_nll": losses.multinomial_nll, "tf": tf} 51 | get_custom_objects().update(custom_objects) 52 | model = load_model(model_file, compile=False) 53 | print("model loaded successfully") 54 | return model 55 | 56 | def fetch_peak_predictions(model, peaks, input_len, genome_fasta, batch_size, debug_mode=False, lite=False, forward_only=False): 57 | peak_ids = [] 58 | pred_counts = [] 59 | pred_profiles = [] 60 | if not forward_only: 61 | revcomp_counts = [] 62 | revcomp_profiles = [] 63 | 64 |
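# Added note (not part of the original file): a sketch of the reverse-complement
# trick this function uses. The revcomp strand is obtained by reversing both the
# length and channel axes of the one-hot batch, i.e. seqs[:, ::-1, ::-1]. Because
# the channels are ordered A,C,G,T, flipping the channel axis complements each
# base while flipping the length axis reverses the sequence; for example,
# "AC" -> [[1,0,0,0],[0,1,0,0]] -> reversed/flipped -> [[0,0,1,0],[0,0,0,1]] -> "GT".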
# peak sequence generator 65 | peak_gen = PeakGenerator(peaks=peaks, 66 | input_len=input_len, 67 | genome_fasta=genome_fasta, 68 | batch_size=batch_size, 69 | debug_mode=debug_mode) 70 | 71 | for i in tqdm(range(len(peak_gen))): 72 | batch_peak_ids, seqs = peak_gen[i] 73 | revcomp_seq = seqs[:, ::-1, ::-1] 74 | 75 | if lite: 76 | batch_preds = model.predict([seqs, 77 | np.zeros((len(seqs), model.output_shape[0][1])), 78 | np.zeros((len(seqs), ))], 79 | verbose=False) 80 | 81 | if not forward_only: 82 | revcomp_batch_preds = model.predict([revcomp_seq, 83 | np.zeros((len(revcomp_seq), model.output_shape[0][1])), 84 | np.zeros((len(revcomp_seq), ))], 85 | verbose=False) 86 | else: 87 | batch_preds = model.predict(seqs, verbose=False) 88 | if not forward_only: 89 | revcomp_batch_preds = model.predict(revcomp_seq, verbose=False) 90 | 91 | batch_preds[1] = np.array([batch_preds[1][i] for i in range(len(batch_preds[1]))]) 92 | pred_counts.extend(np.exp(batch_preds[1])) 93 | pred_profiles.extend(np.array(batch_preds[0])) # np.squeeze(softmax()) to get probability profile 94 | 95 | if not forward_only: 96 | revcomp_batch_preds[1] = np.array([revcomp_batch_preds[1][i] for i in range(len(revcomp_batch_preds[1]))]) 97 | revcomp_counts.extend(np.exp(revcomp_batch_preds[1])) 98 | revcomp_profiles.extend(np.array(revcomp_batch_preds[0])) # np.squeeze(softmax()) to get probability profile 99 | 100 | peak_ids.extend(batch_peak_ids) 101 | 102 | peak_ids = np.array(peak_ids) 103 | pred_counts = np.array(pred_counts) 104 | pred_profiles = np.array(pred_profiles) 105 | 106 | if not forward_only: 107 | revcomp_counts = np.array(revcomp_counts) 108 | revcomp_profiles = np.array(revcomp_profiles) 109 | average_counts = np.average([pred_counts,revcomp_counts],axis=0) 110 | average_profiles = np.average([pred_profiles,revcomp_profiles[:,::-1]],axis=0) 111 | return peak_ids,average_counts,average_profiles 112 | else: 113 | return peak_ids,pred_counts,pred_profiles 114 | 115 | def fetch_variant_predictions(model, variants_table, input_len, genome_fasta, batch_size, debug_mode=False, lite=False, shuf=False, forward_only=False): 116 | variant_ids = [] 117 | allele1_pred_counts = [] 118 | allele2_pred_counts = [] 119 | allele1_pred_profiles = [] 120 | allele2_pred_profiles = [] 121 | if not forward_only: 122 | revcomp_allele1_pred_counts = [] 123 | revcomp_allele2_pred_counts = [] 124 | revcomp_allele1_pred_profiles = [] 125 | revcomp_allele2_pred_profiles = [] 126 | 127 | # variant sequence generator 128 | var_gen = VariantGenerator(variants_table=variants_table, 129 | input_len=input_len, 130 | genome_fasta=genome_fasta, 131 | batch_size=batch_size, 132 | debug_mode=False, 133 | shuf=shuf) 134 | 135 | for i in tqdm(range(len(var_gen))): 136 | 137 | batch_variant_ids, allele1_seqs, allele2_seqs = var_gen[i] 138 | revcomp_allele1_seqs = allele1_seqs[:, ::-1, ::-1] 139 | revcomp_allele2_seqs = allele2_seqs[:, ::-1, ::-1] 140 | 141 | if lite: 142 | allele1_batch_preds = model.predict([allele1_seqs, 143 | np.zeros((len(allele1_seqs), model.output_shape[0][1])), 144 | np.zeros((len(allele1_seqs), ))], 145 | verbose=False) 146 | allele2_batch_preds = model.predict([allele2_seqs, 147 | np.zeros((len(allele2_seqs), model.output_shape[0][1])), 148 | np.zeros((len(allele2_seqs), ))], 149 | verbose=False) 150 | 151 | if not forward_only: 152 | revcomp_allele1_batch_preds = model.predict([revcomp_allele1_seqs, 153 | np.zeros((len(revcomp_allele1_seqs), model.output_shape[0][1])), 154 | np.zeros((len(revcomp_allele1_seqs), ))], 
155 | verbose=False) 156 | revcomp_allele2_batch_preds = model.predict([revcomp_allele2_seqs, 157 | np.zeros((len(revcomp_allele2_seqs), model.output_shape[0][1])), 158 | np.zeros((len(revcomp_allele2_seqs), ))], 159 | verbose=False) 160 | else: 161 | allele1_batch_preds = model.predict(allele1_seqs, verbose=False) 162 | allele2_batch_preds = model.predict(allele2_seqs, verbose=False) 163 | if not forward_only: 164 | revcomp_allele1_batch_preds = model.predict(revcomp_allele1_seqs, verbose=False) 165 | revcomp_allele2_batch_preds = model.predict(revcomp_allele2_seqs, verbose=False) 166 | 167 | allele1_batch_preds[1] = np.array([allele1_batch_preds[1][i] for i in range(len(allele1_batch_preds[1]))]) 168 | allele2_batch_preds[1] = np.array([allele2_batch_preds[1][i] for i in range(len(allele2_batch_preds[1]))]) 169 | allele1_pred_counts.extend(np.exp(allele1_batch_preds[1])) 170 | allele2_pred_counts.extend(np.exp(allele2_batch_preds[1])) 171 | allele1_pred_profiles.extend(np.array(allele1_batch_preds[0])) # np.squeeze(softmax()) to get probability profile 172 | allele2_pred_profiles.extend(np.array(allele2_batch_preds[0])) 173 | 174 | if not forward_only: 175 | revcomp_allele1_batch_preds[1] = np.array([revcomp_allele1_batch_preds[1][i] for i in range(len(revcomp_allele1_batch_preds[1]))]) 176 | revcomp_allele2_batch_preds[1] = np.array([revcomp_allele2_batch_preds[1][i] for i in range(len(revcomp_allele2_batch_preds[1]))]) 177 | revcomp_allele1_pred_counts.extend(np.exp(revcomp_allele1_batch_preds[1])) 178 | revcomp_allele2_pred_counts.extend(np.exp(revcomp_allele2_batch_preds[1])) 179 | revcomp_allele1_pred_profiles.extend(np.array(revcomp_allele1_batch_preds[0])) # np.squeeze(softmax()) to get probability profile 180 | revcomp_allele2_pred_profiles.extend(np.array(revcomp_allele2_batch_preds[0])) 181 | 182 | variant_ids.extend(batch_variant_ids) 183 | 184 | variant_ids = np.array(variant_ids) 185 | allele1_pred_counts = np.array(allele1_pred_counts) 186 | allele2_pred_counts = np.array(allele2_pred_counts) 187 | allele1_pred_profiles = np.array(allele1_pred_profiles) 188 | allele2_pred_profiles = np.array(allele2_pred_profiles) 189 | 190 | if not forward_only: 191 | revcomp_allele1_pred_counts = np.array(revcomp_allele1_pred_counts) 192 | revcomp_allele2_pred_counts = np.array(revcomp_allele2_pred_counts) 193 | revcomp_allele1_pred_profiles = np.array(revcomp_allele1_pred_profiles) 194 | revcomp_allele2_pred_profiles = np.array(revcomp_allele2_pred_profiles) 195 | average_allele1_pred_counts = np.average([allele1_pred_counts,revcomp_allele1_pred_counts],axis=0) 196 | average_allele1_pred_profiles = np.average([allele1_pred_profiles,revcomp_allele1_pred_profiles[:,::-1]],axis=0) 197 | average_allele2_pred_counts = np.average([allele2_pred_counts,revcomp_allele2_pred_counts],axis=0) 198 | average_allele2_pred_profiles = np.average([allele2_pred_profiles,revcomp_allele2_pred_profiles[:,::-1]],axis=0) 199 | return variant_ids, average_allele1_pred_counts, average_allele2_pred_counts, \ 200 | average_allele1_pred_profiles, average_allele2_pred_profiles 201 | else: 202 | return variant_ids, allele1_pred_counts, allele2_pred_counts, \ 203 | allele1_pred_profiles, allele2_pred_profiles 204 | 205 | def get_variant_scores_with_peaks(allele1_pred_counts, allele2_pred_counts, 206 | allele1_pred_profiles, allele2_pred_profiles, pred_counts): 207 | # logfc = np.log2(allele2_pred_counts / allele1_pred_counts) 208 | # jsd = np.array([jensenshannon(x,y,base=2.0) for x,y in zip(allele2_pred_profiles, 
allele1_pred_profiles)]) 209 | 210 | logfc, jsd = get_variant_scores(allele1_pred_counts, allele2_pred_counts, 211 | allele1_pred_profiles, allele2_pred_profiles) 212 | allele1_quantile = np.array([np.max([np.mean(pred_counts < x), (1/len(pred_counts))]) for x in allele1_pred_counts]) 213 | allele2_quantile = np.array([np.max([np.mean(pred_counts < x), (1/len(pred_counts))]) for x in allele2_pred_counts]) 214 | 215 | return logfc, jsd, allele1_quantile, allele2_quantile 216 | 217 | def get_variant_scores(allele1_pred_counts, allele2_pred_counts, 218 | allele1_pred_profiles, allele2_pred_profiles): 219 | 220 | print('allele1_pred_counts shape:', allele1_pred_counts.shape) 221 | print('allele2_pred_counts shape:', allele2_pred_counts.shape) 222 | print('allele1_pred_profiles shape:', allele1_pred_profiles.shape) 223 | print('allele2_pred_profiles shape:', allele2_pred_profiles.shape) 224 | 225 | logfc = np.squeeze(np.log2(allele2_pred_counts / allele1_pred_counts)) 226 | jsd = np.squeeze([jensenshannon(x, y, base=2.0) 227 | for x,y in zip(softmax(allele2_pred_profiles), 228 | softmax(allele1_pred_profiles))]) 229 | 230 | print('logfc shape:', logfc.shape) 231 | print('jsd shape:', jsd.shape) 232 | 233 | return logfc, jsd 234 | 235 | def adjust_indel_jsd(variants_table,allele1_pred_profiles,allele2_pred_profiles,original_jsd): 236 | allele1_pred_profiles = softmax(allele1_pred_profiles) 237 | allele2_pred_profiles = softmax(allele2_pred_profiles) 238 | indel_idx = [] 239 | for i, row in variants_table.iterrows(): 240 | allele1, allele2 = row[['allele1','allele2']] 241 | if allele1 == "-": 242 | allele1 = "" 243 | if allele2 == "-": 244 | allele2 = "" 245 | if len(allele1) != len(allele2): 246 | indel_idx += [i] 247 | 248 | adjusted_jsd = [] 249 | for i in indel_idx: 250 | row = variants_table.iloc[i] 251 | allele1, allele2 = row[['allele1','allele2']] 252 | if allele1 == "-": 253 | allele1 = "" 254 | if allele2 == "-": 255 | allele2 = "" 256 | 257 | allele1_length = len(allele1) 258 | allele2_length = len(allele2) 259 | 260 | allele1_p = allele1_pred_profiles[i] 261 | allele2_p = allele2_pred_profiles[i] 262 | assert len(allele1_p) == len(allele2_p) 263 | assert allele1_length != allele2_length 264 | flank_size = len(allele1_p)//2 265 | allele1_left_flank = allele1_p[:flank_size] 266 | allele2_left_flank = allele2_p[:flank_size] 267 | 268 | if allele1_length > allele2_length: 269 | allele1_right_flank = np.concatenate([allele1_p[flank_size:flank_size+allele2_length],allele1_p[flank_size+allele1_length:]]) 270 | allele2_right_flank = allele2_p[flank_size:allele2_length-allele1_length] 271 | else: 272 | allele1_right_flank = allele1_p[flank_size:allele1_length-allele2_length] 273 | allele2_right_flank = np.concatenate([allele2_p[flank_size:flank_size+allele1_length], allele2_p[flank_size+allele2_length:]]) 274 | 275 | 276 | adjusted_allele1_p = np.concatenate([allele1_left_flank,allele1_right_flank]) 277 | adjusted_allele2_p = np.concatenate([allele2_left_flank,allele2_right_flank]) 278 | adjusted_allele1_p = adjusted_allele1_p/np.sum(adjusted_allele1_p) 279 | adjusted_allele2_p = adjusted_allele2_p/np.sum(adjusted_allele2_p) 280 | assert len(adjusted_allele1_p) == len(adjusted_allele2_p) 281 | adjusted_j = jensenshannon(adjusted_allele1_p,adjusted_allele2_p,base=2.0) 282 | adjusted_jsd += [adjusted_j] 283 | 284 | adjusted_jsd_list = original_jsd.copy() 285 | if len(indel_idx) > 0: 286 | for i in range(len(indel_idx)): 287 | idx = indel_idx[i] 288 | adjusted_jsd_list[idx] = adjusted_jsd[i] 289 
| 290 | return indel_idx, adjusted_jsd_list 291 | 292 | 293 | def create_shuffle_table(variants_table, random_seed=None, total_shuf=None, num_shuf=None): 294 | if total_shuf != None: 295 | if len(variants_table) > total_shuf: 296 | shuf_variants_table = variants_table.sample(total_shuf, 297 | random_state=random_seed, 298 | ignore_index=True, 299 | replace=False) 300 | else: 301 | shuf_variants_table = variants_table.sample(total_shuf, 302 | random_state=random_seed, 303 | ignore_index=True, 304 | replace=True) 305 | shuf_variants_table['random_seed'] = np.random.permutation(len(shuf_variants_table)) 306 | else: 307 | if num_shuf != None: 308 | total_shuf = len(variants_table) * num_shuf 309 | shuf_variants_table = variants_table.sample(total_shuf, 310 | random_state=random_seed, 311 | ignore_index=True, 312 | replace=True) 313 | shuf_variants_table['random_seed'] = np.random.permutation(len(shuf_variants_table)) 314 | else: 315 | ## empty dataframe 316 | shuf_variants_table = pd.DataFrame() 317 | return shuf_variants_table 318 | 319 | def get_pvals(obs, bg, tail): 320 | sorted_bg = np.sort(bg) 321 | if tail == 'right' or tail == 'both': 322 | rank_right = len(sorted_bg) - np.searchsorted(sorted_bg, obs, side='left') 323 | pval_right = (rank_right + 1) / (len(sorted_bg) + 1) 324 | if tail == 'right': 325 | return pval_right 326 | if tail == 'left' or tail == 'both': 327 | rank_left = np.searchsorted(sorted_bg, obs, side='right') 328 | pval_left = (rank_left + 1) / (len(sorted_bg) + 1) 329 | if tail == 'left': 330 | return pval_left 331 | assert tail == 'both' 332 | min_pval = np.minimum(pval_left, pval_right) 333 | pval_both = min_pval * 2 334 | 335 | return pval_both 336 | 337 | -------------------------------------------------------------------------------- /src/utils/io.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import pybedtools 4 | from scipy.spatial.distance import jensenshannon 5 | 6 | 7 | def get_variant_schema(schema): 8 | var_SCHEMA = {'original': ['chr', 'pos', 'variant_id', 'allele1', 'allele2'], 9 | 'plink': ['chr', 'variant_id', 'ignore1', 'pos', 'allele1', 'allele2'], 10 | 'plink2': ['chr', 'variant_id', 'pos', 'allele1', 'allele2'], 11 | 'bed': ['chr', 'pos', 'end', 'allele1', 'allele2', 'variant_id'], 12 | 'chrombpnet': ['chr', 'pos', 'allele1', 'allele2', 'variant_id']} 13 | return var_SCHEMA[schema] 14 | 15 | 16 | def get_peak_schema(schema): 17 | PEAK_SCHEMA = {'narrowpeak': ['chr', 'start', 'end', 'peak_id', 'peak_score', 18 | 5, 6, 7, 'rank', 'summit']} 19 | return PEAK_SCHEMA[schema] 20 | 21 | 22 | def validate_alleles(variants_table): 23 | """Validate that alleles contain only valid nucleotides (ACGT) or deletion marker (-)""" 24 | valid_chars = set('ACGT-') 25 | 26 | for col in ['allele1', 'allele2']: 27 | if col in variants_table.columns: 28 | for idx, allele in enumerate(variants_table[col]): 29 | 30 | allele_str = str(allele).upper() 31 | 32 | if not set(allele_str).issubset(valid_chars): 33 | raise ValueError(f"Invalid characters in {col} at row {idx}: '{allele}'. Only A, C, G, T, and - are allowed.") 34 | 35 | # If the allele contains "-", it should be a single character 36 | if '-' in allele_str and len(allele_str) > 1: 37 | raise ValueError(f"Invalid allele at row {idx}: '{allele}'. 
Use a single '-' to represent INDELs.") 38 | 39 | 40 | def load_variant_table(table_path, schema): 41 | # Read file first to check structure 42 | temp_df = pd.read_csv(table_path, header=None, sep='\t', nrows=5) 43 | expected_cols = len(get_variant_schema(schema)) 44 | 45 | if temp_df.shape[1] != expected_cols: 46 | raise ValueError(f"File has {temp_df.shape[1]} columns but {schema} schema expects {expected_cols} columns") 47 | 48 | variants_table = pd.read_csv(table_path, header=None, sep='\t', names=get_variant_schema(schema)) 49 | variants_table.drop(columns=[str(x) for x in variants_table.columns if str(x).startswith('ignore')], inplace=True) 50 | variants_table['chr'] = variants_table['chr'].astype(str) 51 | has_chr_prefix = any('chr' in x.lower() for x in variants_table['chr'].tolist()) 52 | if not has_chr_prefix: 53 | variants_table['chr'] = 'chr' + variants_table['chr'] 54 | if schema == "bed": 55 | # Convert to 1-based indexing 56 | variants_table['pos'] = variants_table['pos'] + 1 57 | 58 | # Validate alleles 59 | validate_alleles(variants_table) 60 | 61 | return variants_table 62 | 63 | 64 | def add_missing_columns_to_peaks_df(peaks, schema): 65 | if schema != 'narrowpeak': 66 | raise ValueError("Schema not supported") 67 | 68 | required_columns = get_peak_schema(schema) 69 | num_current_columns = peaks.shape[1] 70 | 71 | if num_current_columns == 10: 72 | peaks.columns = required_columns[:num_current_columns] 73 | return peaks # No missing columns, return as is 74 | 75 | elif num_current_columns < 3: 76 | raise ValueError("Peaks dataframe has fewer than 3 columns, which is invalid") 77 | 78 | elif num_current_columns > 10: 79 | raise ValueError("Peaks dataframe has greater than 10 columns, which is invalid") 80 | 81 | # Add missing columns to reach a total of 10 columns 82 | peaks.columns = required_columns[:num_current_columns] 83 | columns_to_add = required_columns[num_current_columns:] 84 | 85 | for column in columns_to_add: 86 | peaks[column] = '.' 
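# Added illustration (hypothetical values, not from the original source): a minimal
# 3-column input [chr, start, end] such as ("chr1", 100, 600) is padded by the loop
# above with '.' placeholders for peak_id, peak_score, columns 5-7, and rank; the
# summit is then recomputed below as the midpoint offset, (600 - 100) // 2 = 250.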
87 | 88 | # Calculate the summit column 89 | peaks['summit'] = (peaks['end'] - peaks['start']) // 2 90 | 91 | return peaks 92 | 93 | 94 | def load_genes(genes_file): 95 | """Load genes from file and return as pybedtools BedTool object.""" 96 | 97 | gene_df = pd.read_table(genes_file, header=None) 98 | print(gene_df.head()) 99 | gene_bed = pybedtools.BedTool.from_dataframe(gene_df) 100 | 101 | return gene_bed 102 | 103 | 104 | def load_peaks(peaks_file): 105 | """Load peaks from file and return as pybedtools BedTool object.""" 106 | 107 | peak_df = pd.read_table(peaks_file, header=None) 108 | print(peak_df.head()) 109 | peak_bed = pybedtools.BedTool.from_dataframe(peak_df) 110 | 111 | return peak_bed -------------------------------------------------------------------------------- /src/utils/losses.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_probability as tfp 3 | 4 | #from https://github.com/kundajelab/basepair/blob/cda0875571066343cdf90aed031f7c51714d991a/basepair/losses.py#L87 5 | def multinomial_nll(true_counts, logits): 6 | """Compute the multinomial negative log-likelihood 7 | Args: 8 | true_counts: observed count values 9 | logits: predicted logit values 10 | """ 11 | counts_per_example = tf.reduce_sum(true_counts, axis=-1) 12 | dist = tfp.distributions.Multinomial(total_count=counts_per_example, 13 | logits=logits) 14 | return (-tf.reduce_sum(dist.log_prob(true_counts)) / 15 | tf.cast(tf.shape(true_counts)[0], dtype=tf.float32)) 16 | 17 | -------------------------------------------------------------------------------- /src/utils/one_hot.py: -------------------------------------------------------------------------------- 1 | """ 2 | Written by Alex Tseng 3 | 4 | https://gist.github.com/amtseng/010dd522daaabc92b014f075a34a0a0b 5 | """ 6 | 7 | import numpy as np 8 | 9 | def dna_to_one_hot(seqs): 10 | """ 11 | Converts a list of DNA ("ACGT") sequences to one-hot encodings, where the 12 | position of 1s is ordered alphabetically by "ACGT". `seqs` must be a list 13 | of N strings, where every string is the same length L. Returns an N x L x 4 14 | NumPy array of one-hot encodings, in the same order as the input sequences. 15 | All bases will be converted to upper-case prior to performing the encoding. 16 | Any bases that are not "ACGT" will be given an encoding of all 0s. 17 | """ 18 | seq_len = len(seqs[0]) 19 | assert np.all(np.array([len(s) for s in seqs]) == seq_len) 20 | 21 | # Join all sequences together into one long string, all uppercase 22 | seq_concat = "".join(seqs).upper() + "ACGT" 23 | # Add one example of each base, so np.unique doesn't miss indices later 24 | 25 | one_hot_map = np.identity(5)[:, :-1].astype(np.int8) 26 | 27 | # Convert string into array of ASCII character codes; 28 | base_vals = np.frombuffer(bytearray(seq_concat, "utf8"), dtype=np.int8) 29 | 30 | # Anything that's not an A, C, G, or T gets assigned a higher code 31 | base_vals[~np.isin(base_vals, np.array([65, 67, 71, 84]))] = 85 32 | 33 | # Convert the codes into indices in [0, 4], in ascending order by code 34 | _, base_inds = np.unique(base_vals, return_inverse=True) 35 | 36 | # Get the one-hot encoding for those indices, and reshape back to separate 37 | return one_hot_map[base_inds[:-4]].reshape((len(seqs), seq_len, 4)) 38 | 39 | 40 | def one_hot_to_dna(one_hot): 41 | """ 42 | Converts a one-hot encoding into a list of DNA ("ACGT") sequences, where the 43 | position of 1s is ordered alphabetically by "ACGT". 
`one_hot` must be an 44 | N x L x 4 array of one-hot encodings. Returns a list of N "ACGT" strings, 45 | each of length L, in the same order as the input array. The returned 46 | sequences will only consist of letters "A", "C", "G", "T", or "N" (all 47 | upper-case). Any encodings that are all 0s will be translated to "N". 48 | """ 49 | bases = np.array(["A", "C", "G", "T", "N"]) 50 | # Create an N x L array filled with 4, the index for "N" 51 | one_hot_inds = np.tile(one_hot.shape[2], one_hot.shape[:2]) 52 | 53 | # Get indices of where the 1s are 54 | batch_inds, seq_inds, base_inds = np.where(one_hot) 55 | 56 | # In each of the locations in the N x L array, fill in the location of the 1 57 | one_hot_inds[batch_inds, seq_inds] = base_inds 58 | 59 | # Fetch the corresponding base for each position using indexing 60 | seq_array = bases[one_hot_inds] 61 | return ["".join(seq) for seq in seq_array] 62 | -------------------------------------------------------------------------------- /src/utils/shap_utils.py: -------------------------------------------------------------------------------- 1 | # Av's code with a bit of reformatting 2 | # Adapted from Zahoor's mtbatchgen 3 | 4 | from tensorflow.keras.utils import get_custom_objects 5 | from tensorflow.keras.models import load_model 6 | import tensorflow as tf 7 | import scipy.stats 8 | from scipy.spatial.distance import jensenshannon 9 | import pandas as pd 10 | import os 11 | import argparse 12 | import numpy as np 13 | import h5py 14 | import math 15 | from tqdm import tqdm 16 | import sys 17 | sys.path.append('..') 18 | from generators.variant_generator import VariantGenerator 19 | from generators.peak_generator import PeakGenerator 20 | from utils import argmanager, losses 21 | import shap 22 | from deeplift.dinuc_shuffle import dinuc_shuffle 23 | tf.compat.v1.disable_v2_behavior() 24 | 25 | 26 | def combine_mult_and_diffref(mult, orig_inp, bg_data): 27 | to_return = [] 28 | 29 | for l in [0]: 30 | projected_hypothetical_contribs = \ 31 | np.zeros_like(bg_data[l]).astype("float") 32 | assert len(orig_inp[l].shape)==2 33 | 34 | # At each position in the input sequence, we iterate over the 35 | # one-hot encoding possibilities (e.g., for genomic sequence, 36 | # this is ACGT i.e. 1000, 0100, 0010 and 0001) and compute the 37 | # hypothetical difference-from-reference in each case. We then 38 | # multiply the hypothetical differences-from-reference with 39 | # the multipliers to get the hypothetical contributions. For 40 | # each of the one-hot encoding possibilities, the hypothetical 41 | # contributions are then summed across the ACGT axis to 42 | # estimate the total hypothetical contribution of each 43 | # position. This per-position hypothetical contribution is then 44 | # assigned ("projected") onto whichever base was present in the 45 | # hypothetical sequence. The reason this is a fast estimate of 46 | # what the importance scores *would* look like if different 47 | # bases were present in the underlying sequence is that the 48 | # multipliers are computed once using the original sequence, 49 | # and are not computed again for each hypothetical sequence.
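# Added toy example (values hypothetical, not from the original source): at a
# single position with background base frequencies b = [0.3, 0.2, 0.3, 0.2]
# (A, C, G, T) and multipliers m = [0.1, -0.2, 0.4, 0.0], the hypothetical
# contribution of placing a "G" there is dot(onehot(G) - b, m)
#   = (0 - 0.3)*0.1 + (0 - 0.2)*(-0.2) + (1 - 0.3)*0.4 + (0 - 0.2)*0.0 = 0.29.
# The loop below computes this for every base at every position, and the
# result is then averaged over the dinucleotide-shuffled background samples.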
50 | for i in range(orig_inp[l].shape[-1]): 51 | hypothetical_input = np.zeros_like(orig_inp[l]).astype("float") 52 | hypothetical_input[:, i] = 1.0 53 | hypothetical_difference_from_reference = \ 54 | (hypothetical_input[None, :, :] - bg_data[l]) 55 | hypothetical_contribs = hypothetical_difference_from_reference * \ 56 | mult[l] 57 | projected_hypothetical_contribs[:, :, i] = \ 58 | np.sum(hypothetical_contribs, axis=-1) 59 | 60 | to_return.append(np.mean(projected_hypothetical_contribs,axis=0)) 61 | 62 | if len(orig_inp)>1: 63 | to_return.append(np.zeros_like(orig_inp[1])) 64 | 65 | return to_return 66 | 67 | 68 | def shuffle_several_times(s): 69 | numshuffles=20 70 | if len(s)==2: 71 | return [np.array([dinuc_shuffle(s[0]) for i in range(numshuffles)]), 72 | np.array([s[1] for i in range(numshuffles)])] 73 | else: 74 | return [np.array([dinuc_shuffle(s[0]) for i in range(numshuffles)])] 75 | 76 | 77 | def get_weightedsum_meannormed_logits(model): 78 | # See Google slide deck for explanations 79 | # We meannorm as per section titled 80 | # "Adjustments for Softmax Layers" in the DeepLIFT paper 81 | meannormed_logits = (model.outputs[0] - \ 82 | tf.reduce_mean(model.outputs[0], axis=1)[:, None]) 83 | 84 | # 'stop_gradient' will prevent importance from being propagated 85 | # through this operation; we do this because we just want to treat 86 | # the post-softmax probabilities as 'weights' on the different 87 | # logits, without having the network explain how the probabilities 88 | # themselves were derived. Could be worth contrasting explanations 89 | # derived with and without stop_gradient enabled... 90 | stopgrad_meannormed_logits = tf.stop_gradient(meannormed_logits) 91 | softmax_out = tf.nn.softmax(stopgrad_meannormed_logits, axis=1) 92 | 93 | # Weight the logits according to the softmax probabilities, take 94 | # the sum for each example. This mirrors what was done for the 95 | # bpnet paper. 
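# Added restatement of the comment above (for clarity; not in the original source):
# with mean-normalized logits l_i for one example, the scalar explained below is
#     sum_i softmax(stop_gradient(l))_i * l_i
# so importance flows through the logits themselves but not through the softmax
# weights used to combine them.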
96 | weightedsum_meannormed_logits = tf.reduce_sum(softmax_out * \ 97 | meannormed_logits, 98 | axis=1) 99 | 100 | return weightedsum_meannormed_logits 101 | 102 | 103 | def fetch_shap(model, variants_table, input_len, genome_fasta, batch_size, debug_mode=False, lite=False, bias=None, shuf=False,shap_type="counts"): 104 | variant_ids = [] 105 | allele1_counts_shap = [] 106 | allele2_counts_shap = [] 107 | allele1_profile_shap = [] 108 | allele2_profile_shap = [] 109 | allele1_inputs = [] 110 | allele2_inputs = [] 111 | 112 | # variant sequence generator 113 | var_gen = VariantGenerator(variants_table=variants_table, 114 | input_len=input_len, 115 | genome_fasta=genome_fasta, 116 | batch_size=batch_size, 117 | debug_mode=False, 118 | shuf=shuf) 119 | 120 | for i in tqdm(range(len(var_gen))): 121 | 122 | batch_variant_ids, allele1_seqs, allele2_seqs = var_gen[i] 123 | 124 | if lite: 125 | if shap_type == "counts": 126 | counts_model_input = [model.input[0], model.input[2]] 127 | allele1_input = [allele1_seqs, np.zeros((allele1_seqs.shape[0], 1))] 128 | allele2_input = [allele2_seqs, np.zeros((allele2_seqs.shape[0], 1))] 129 | 130 | profile_model_counts_explainer = shap.explainers.deep.TFDeepExplainer( 131 | (counts_model_input, tf.reduce_sum(model.outputs[1], axis=-1)), 132 | shuffle_several_times, 133 | combine_mult_and_diffref=combine_mult_and_diffref) 134 | 135 | allele1_counts_shap_batch = profile_model_counts_explainer.shap_values( 136 | allele1_input, progress_message=10) 137 | allele2_counts_shap_batch = profile_model_counts_explainer.shap_values( 138 | allele2_input, progress_message=10) 139 | 140 | allele1_counts_shap_batch = allele1_counts_shap_batch[0] * allele1_input[0] 141 | allele2_counts_shap_batch = allele2_counts_shap_batch[0] * allele2_input[0] 142 | 143 | allele1_counts_shap.extend(allele1_counts_shap_batch) 144 | allele2_counts_shap.extend(allele2_counts_shap_batch) 145 | 146 | else: 147 | assert shap_type == "profile" 148 | profile_model_input = [model.input[0], model.input[1]] 149 | outlen = model.output_shape[0][1] 150 | 151 | allele1_input = [allele1_seqs, np.zeros((allele1_seqs.shape[0], outlen))] 152 | allele2_input = [allele2_seqs, np.zeros((allele2_seqs.shape[0], outlen))] 153 | 154 | weightedsum_meannormed_logits = get_weightedsum_meannormed_logits(model) 155 | profile_model_profile_explainer = shap.explainers.deep.TFDeepExplainer( 156 | (profile_model_input, weightedsum_meannormed_logits), 157 | shuffle_several_times, 158 | combine_mult_and_diffref=combine_mult_and_diffref) 159 | 160 | allele1_profile_shap_batch = profile_model_profile_explainer.shap_values( 161 | allele1_input, progress_message=10) 162 | allele2_profile_shap_batch = profile_model_profile_explainer.shap_values( 163 | allele2_input, progress_message=10) 164 | 165 | allele1_profile_shap_batch = allele1_profile_shap_batch[0] * allele1_input[0] 166 | allele2_profile_shap_batch = allele2_profile_shap_batch[0] * allele2_input[0] 167 | 168 | allele1_profile_shap.extend(allele1_profile_shap_batch) 169 | allele2_profile_shap.extend(allele2_profile_shap_batch) 170 | 171 | else: 172 | allele1_input = allele1_seqs 173 | allele2_input = allele2_seqs 174 | 175 | if shap_type == "counts": 176 | counts_model_input = model.input 177 | profile_model_counts_explainer = shap.explainers.deep.TFDeepExplainer( 178 | (counts_model_input, tf.reduce_sum(model.outputs[1], axis=-1)), 179 | shuffle_several_times, 180 | combine_mult_and_diffref=combine_mult_and_diffref) 181 | 182 | allele1_counts_shap_batch = 
profile_model_counts_explainer.shap_values( 183 | allele1_input, progress_message=10) 184 | allele2_counts_shap_batch = profile_model_counts_explainer.shap_values( 185 | allele2_input, progress_message=10) 186 | 187 | # allele1_counts_shap_batch = allele1_counts_shap_batch * allele1_input 188 | # allele2_counts_shap_batch = allele2_counts_shap_batch * allele2_input 189 | 190 | allele1_counts_shap.extend(allele1_counts_shap_batch) 191 | allele2_counts_shap.extend(allele2_counts_shap_batch) 192 | 193 | allele1_inputs.extend(allele1_input) 194 | allele2_inputs.extend(allele2_input) 195 | 196 | else: 197 | assert shap_type == "profile" 198 | profile_model_input = model.input 199 | weightedsum_meannormed_logits = get_weightedsum_meannormed_logits(model) 200 | profile_model_profile_explainer = shap.explainers.deep.TFDeepExplainer( 201 | (profile_model_input, weightedsum_meannormed_logits), 202 | shuffle_several_times, 203 | combine_mult_and_diffref=combine_mult_and_diffref) 204 | 205 | allele1_profile_shap_batch = profile_model_profile_explainer.shap_values( 206 | allele1_input, progress_message=10) 207 | allele2_profile_shap_batch = profile_model_profile_explainer.shap_values( 208 | allele2_input, progress_message=10) 209 | 210 | # allele1_profile_shap_batch = allele1_profile_shap_batch * allele1_input 211 | # allele2_profile_shap_batch = allele2_profile_shap_batch * allele2_input 212 | 213 | allele1_profile_shap.extend(allele1_profile_shap_batch) 214 | allele2_profile_shap.extend(allele2_profile_shap_batch) 215 | 216 | allele1_inputs.extend(allele1_input) 217 | allele2_inputs.extend(allele2_input) 218 | 219 | variant_ids.extend(batch_variant_ids) 220 | 221 | if shap_type == "counts": 222 | return np.array(variant_ids), np.array(allele1_inputs), np.array(allele2_inputs), \ 223 | np.array(allele1_counts_shap), np.array(allele2_counts_shap) 224 | else: 225 | return np.array(variant_ids), np.array(allele1_inputs), np.array(allele2_inputs), \ 226 | np.array(allele1_profile_shap), np.array(allele2_profile_shap) 227 | -------------------------------------------------------------------------------- /src/variant_annotation.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pybedtools 3 | from utils.argmanager import * 4 | from utils.io import * 5 | pd.set_option('display.max_columns', 20) 6 | 7 | 8 | def main(): 9 | args = fetch_variant_annotation_args() 10 | print(args) 11 | variant_scores_file = args.list 12 | output_prefix = args.out_prefix 13 | peak_path = args.peaks 14 | hits_path = args.hits 15 | genes = args.genes 16 | 17 | print("Loading variant scores from:", variant_scores_file) 18 | variant_scores = pd.read_table(variant_scores_file) 19 | 20 | if args.schema == "bed": 21 | if variant_scores['pos'].equals(variant_scores['end']): 22 | variant_scores['pos'] = variant_scores['pos'] - 1 23 | variant_scores_bed_format = variant_scores[['chr','pos','end','allele1','allele2','variant_id']].copy() 24 | variant_scores_bed_format.sort_values(by=["chr","pos","end"], inplace=True) 25 | else: 26 | # otherwise, assuming output of variant scoring pipeline. 27 | # convert to bed format 28 | try: 29 | variant_scores_bed_format = variant_scores[['chr','pos','allele1','allele2','variant_id']].copy() 30 | except KeyError: 31 | print("Error: The input file does not contain the required columns:" \ 32 | "'chr', 'pos', 'allele1', 'allele2', 'variant_id'. 
" \ 33 | "Please provide scores output of variant_scoring.py (or summarized scores) as input.") 34 | return 35 | variant_scores_bed_format['pos'] = variant_scores_bed_format.apply(lambda x: int(x.pos)-1, axis = 1) 36 | variant_scores_bed_format['end'] = variant_scores_bed_format.apply(lambda x: int(x.pos)+len(x.allele1), axis = 1) 37 | variant_scores_bed_format = variant_scores_bed_format[['chr','pos','end','allele1','allele2','variant_id']] 38 | variant_scores_bed_format.sort_values(by=["chr","pos","end"], inplace=True) 39 | 40 | print() 41 | print(variant_scores_bed_format.head()) 42 | print("Variants table shape:", variant_scores_bed_format.shape) 43 | print() 44 | 45 | variant_bed = pybedtools.BedTool.from_dataframe(variant_scores_bed_format) 46 | 47 | # Process overlaps between variants and provided genes 48 | if args.genes: 49 | print("annotating with closest genes") 50 | 51 | gene_bed = load_genes(genes) 52 | 53 | closest_genes_bed = variant_bed.closest(gene_bed, d=True, t='first', k=3) 54 | 55 | closest_gene_df = closest_genes_bed.to_dataframe(header=None) 56 | 57 | print() 58 | print(closest_gene_df.head()) 59 | print("Closest genes table shape:", closest_gene_df.shape) 60 | print() 61 | 62 | closest_genes = {} 63 | gene_dists = {} 64 | 65 | for index, row in closest_gene_df.iterrows(): 66 | if not row[5] in closest_genes: 67 | closest_genes[row[5]] = [] 68 | gene_dists[row[5]] = [] 69 | closest_genes[row[5]].append(row.iloc[9]) 70 | gene_dists[row[5]].append(row.iloc[-1]) 71 | 72 | closest_gene_df = closest_gene_df.rename({5: 'variant_id'}, axis=1) 73 | closest_gene_df = closest_gene_df[['variant_id']] 74 | closest_gene_df['closest_gene_1'] = closest_gene_df['variant_id'].apply(lambda x: closest_genes[x][0] if len(closest_genes[x]) > 0 else '.') 75 | closest_gene_df['gene_distance_1'] = closest_gene_df['variant_id'].apply(lambda x: gene_dists[x][0] if len(closest_genes[x]) > 0 else '.') 76 | 77 | closest_gene_df['closest_gene_2'] = closest_gene_df['variant_id'].apply(lambda x: closest_genes[x][1] if len(closest_genes[x]) > 1 else '.') 78 | closest_gene_df['gene_distance_2'] = closest_gene_df['variant_id'].apply(lambda x: gene_dists[x][1] if len(closest_genes[x]) > 1 else '.') 79 | 80 | closest_gene_df['closest_gene_3'] = closest_gene_df['variant_id'].apply(lambda x: closest_genes[x][2] if len(closest_genes[x]) > 2 else '.') 81 | closest_gene_df['gene_distance_3'] = closest_gene_df['variant_id'].apply(lambda x: gene_dists[x][2] if len(closest_genes[x]) > 2 else '.') 82 | 83 | closest_gene_df = closest_gene_df[['variant_id', 'closest_gene_1', 'gene_distance_1', 84 | 'closest_gene_2', 'gene_distance_2', 85 | 'closest_gene_3', 'gene_distance_3']] 86 | closest_gene_df.drop_duplicates(inplace=True) 87 | variant_scores = variant_scores.merge(closest_gene_df, on='variant_id', how='left') 88 | 89 | # Process overlaps between variants and provided peak regions 90 | if args.peaks: 91 | print("annotating with peak overlap") 92 | 93 | peak_bed = load_peaks(peak_path) 94 | 95 | peak_intersect_bed = variant_bed.intersect(peak_bed, wa=True, u=True) 96 | 97 | peak_intersect_df = peak_intersect_bed.to_dataframe(names=variant_scores_bed_format.columns.tolist()) 98 | 99 | print() 100 | print(peak_intersect_df.head()) 101 | print("Peak overlap table shape:", peak_intersect_df.shape) 102 | print() 103 | 104 | # If non-empty 105 | if not peak_intersect_df.empty: 106 | variant_scores['peak_overlap'] = variant_scores['variant_id'].isin(peak_intersect_df['variant_id'].tolist()) 107 | else: 108 | # add 
empty column if no peaks overlap found 109 | variant_scores['peak_overlap'] = False 110 | print("No peaks overlap found.") 111 | 112 | # Process overlaps between variants and provided motif hits 113 | if args.hits: 114 | print("annotating with motif hits overlap") 115 | hits_df = pd.read_table(hits_path, header=None) 116 | 117 | # set column names 118 | hits_df.columns = ['chr_hit', 'start_hit', 'end_hit', 'motif', 'score', 'strand', 'class'] 119 | 120 | print(hits_df.head()) 121 | hits_bed = pybedtools.BedTool.from_dataframe(hits_df) 122 | hits_intersect_bed = variant_bed.intersect(hits_bed, wo=True) 123 | print(hits_intersect_bed.head()) 124 | 125 | hits_intersect_df = hits_intersect_bed.to_dataframe(names=variant_scores_bed_format.columns.tolist() + hits_df.columns.tolist() + ['overlap_length']) 126 | 127 | print() 128 | print("Motif hits overlap table shape:", hits_intersect_df.shape) 129 | print() 130 | 131 | # If non-empty 132 | if not hits_intersect_df.empty: 133 | print(hits_intersect_df.head()) 134 | 135 | # Make a boolean column indicating if the variant overlaps with motif hits 136 | variant_scores['hits_overlap'] = variant_scores['variant_id'].isin(hits_intersect_df['variant_id'].tolist()) 137 | 138 | # Collapse the list of motif names for each variant 139 | hits_intersect_df['hits_motifs'] = hits_intersect_df.groupby('variant_id')['motif'].transform(lambda x: ','.join(set(x))) 140 | hits_intersect_df = hits_intersect_df[['variant_id', 'hits_motifs']].drop_duplicates() 141 | variant_scores = variant_scores.merge(hits_intersect_df, on='variant_id', how='left') 142 | variant_scores['hits_motifs'] = variant_scores['hits_motifs'].fillna('-') 143 | 144 | else: 145 | # add empty column if no hits overlap found 146 | variant_scores['hits_overlap'] = False 147 | variant_scores['hits_motifs'] = '-' 148 | print("No motif hits overlap found.") 149 | 150 | print() 151 | print(variant_scores.head()) 152 | print("Annotation table shape:", variant_scores.shape) 153 | print() 154 | 155 | # Print some summary statistics: 156 | if args.peaks: 157 | print("Number of variants overlapping peaks:", variant_scores['peak_overlap'].sum(), "/", variant_scores.shape[0]) 158 | 159 | if args.hits: 160 | print("Number of variants overlapping motif hits:", variant_scores['hits_overlap'].sum(), "/", variant_scores.shape[0]) 161 | 162 | out_file = output_prefix + ".annotations.tsv" 163 | variant_scores.to_csv(out_file, sep="\t", index=False) 164 | 165 | print("DONE") 166 | print() 167 | 168 | 169 | if __name__ == "__main__": 170 | main() -------------------------------------------------------------------------------- /src/variant_scoring.per_chrom.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | import numpy as np 4 | import h5py 5 | from utils import argmanager 6 | from utils.helpers import * 7 | from utils.io import * 8 | 9 | 10 | def main(): 11 | args = argmanager.fetch_scoring_args() 12 | print(args) 13 | 14 | np.random.seed(args.random_seed) 15 | if args.forward_only: 16 | print("running variant scoring only for forward sequences") 17 | 18 | out_dir = os.path.sep.join(args.out_prefix.split(os.path.sep)[:-1]) 19 | if not os.path.exists(out_dir): 20 | raise OSError("Output directory does not exist") 21 | 22 | # load the model and variants 23 | model = load_model_wrapper(args.model) 24 | variants_table = load_variant_table(args.list, args.schema) 25 | variants_table = variants_table.fillna('-') 26 | 27 | chrom_sizes = 
pd.read_csv(args.chrom_sizes, header=None, sep='\t', names=['chrom', 'size']) 28 | chrom_sizes_dict = chrom_sizes.set_index('chrom')['size'].to_dict() 29 | 30 | print("Original variants table shape:", variants_table.shape) 31 | 32 | if args.chrom: 33 | variants_table = variants_table.loc[variants_table['chr'] == args.chrom] 34 | print("Chromosome variants table shape:", variants_table.shape) 35 | 36 | # infer input length 37 | if args.lite: 38 | input_len = model.input_shape[0][1] 39 | else: 40 | input_len = model.input_shape[1] 41 | 42 | print("Input length inferred from the model:", input_len) 43 | 44 | variants_table = variants_table.loc[variants_table.apply(lambda x: get_valid_variants(x.chr, x.pos, x.allele1, x.allele2, input_len, chrom_sizes_dict), axis=1)] 45 | variants_table.reset_index(drop=True, inplace=True) 46 | 47 | print("Final variants table shape:", variants_table.shape) 48 | 49 | if args.shuffled_scores: 50 | shuf_variants_table = pd.read_table(args.shuffled_scores) 51 | print("Shuffled variants table shape:", shuf_variants_table.shape) 52 | shuf_scores_file = args.shuffled_scores 53 | 54 | else: 55 | shuf_variants_table = create_shuffle_table(variants_table, args.random_seed, args.total_shuf, args.num_shuf) 56 | print("Shuffled variants table shape:", shuf_variants_table.shape) 57 | shuf_scores_file = '.'.join([args.out_prefix, "variant_scores.shuffled.tsv"]) 58 | 59 | peak_scores_file = '.'.join([args.out_prefix, "peak_scores.tsv"]) 60 | 61 | if len(shuf_variants_table) > 0: 62 | if args.debug_mode: 63 | shuf_variants_table = shuf_variants_table.sample(10000, random_state=args.random_seed, ignore_index=True) 64 | print() 65 | print(shuf_variants_table.head()) 66 | print("Debug shuffled variants table shape:", shuf_variants_table.shape) 67 | print() 68 | 69 | shuf_variants_done = False 70 | if os.path.isfile(shuf_scores_file): 71 | shuf_variants_table_loaded = pd.read_table(shuf_scores_file) 72 | if shuf_variants_table_loaded['variant_id'].tolist() == shuf_variants_table['variant_id'].tolist(): 73 | shuf_variants_table = shuf_variants_table_loaded.copy() 74 | shuf_variants_done = True 75 | 76 | if not shuf_variants_done: 77 | shuf_variant_ids, shuf_allele1_pred_counts, shuf_allele2_pred_counts, \ 78 | shuf_allele1_pred_profiles, shuf_allele2_pred_profiles = fetch_variant_predictions(model, 79 | shuf_variants_table, 80 | input_len, 81 | args.genome, 82 | args.batch_size, 83 | debug_mode=args.debug_mode, 84 | lite=args.lite, 85 | shuf=True, 86 | forward_only=args.forward_only) 87 | assert np.array_equal(shuf_variants_table["variant_id"].tolist(), shuf_variant_ids) 88 | shuf_variants_table["allele1_pred_counts"] = shuf_allele1_pred_counts 89 | shuf_variants_table["allele2_pred_counts"] = shuf_allele2_pred_counts 90 | 91 | if args.peaks: 92 | if args.peak_chrom_sizes == None: 93 | args.peak_chrom_sizes = args.chrom_sizes 94 | if args.peak_genome == None: 95 | args.peak_genome = args.genome 96 | 97 | peak_chrom_sizes = pd.read_csv(args.peak_chrom_sizes, header=None, sep='\t', names=['chrom', 'size']) 98 | peak_chrom_sizes_dict = peak_chrom_sizes.set_index('chrom')['size'].to_dict() 99 | 100 | peaks = pd.read_csv(args.peaks, header=None, sep='\t') 101 | peaks = add_missing_columns_to_peaks_df(peaks, schema='narrowpeak') 102 | peaks['peak_id'] = peaks['chr'] + ':' + peaks['start'].astype(str) + '-' + peaks['end'].astype(str) 103 | 104 | print("Original peak table shape:", peaks.shape) 105 | 106 | peaks.sort_values(by=['chr', 'start', 'end', 'summit', 'rank'], ascending=[True, 
True, True, True, False], inplace=True) 107 | peaks.drop_duplicates(subset=['chr', 'start', 'end', 'summit'], inplace=True) 108 | peaks = peaks.loc[peaks.apply(lambda x: get_valid_peaks(x.chr, x.start, x.summit, input_len, peak_chrom_sizes_dict), axis=1)] 109 | peaks.reset_index(drop=True, inplace=True) 110 | 111 | print("De-duplicated peak table shape:", peaks.shape) 112 | 113 | if args.debug_mode: 114 | peaks = peaks.sample(10000, random_state=args.random_seed, ignore_index=True) 115 | print() 116 | print(peaks.head()) 117 | print("Debug peak table shape:", peaks.shape) 118 | print() 119 | 120 | if args.max_peaks: 121 | if len(peaks) > args.max_peaks: 122 | peaks = peaks.sample(args.max_peaks, random_state=args.random_seed, ignore_index=True) 123 | print("Subsampled peak table shape:", peaks.shape) 124 | 125 | peak_scores_done = False 126 | if os.path.isfile(peak_scores_file): 127 | peaks_loaded = pd.read_table(peak_scores_file) 128 | if peaks_loaded['peak_id'].tolist() == peaks['peak_id'].tolist(): 129 | peaks = peaks_loaded.copy() 130 | peak_scores_done = True 131 | 132 | if not peak_scores_done: 133 | peak_ids, peak_pred_counts, peak_pred_profiles = fetch_peak_predictions(model, 134 | peaks, 135 | input_len, 136 | args.peak_genome, 137 | args.batch_size, 138 | debug_mode=args.debug_mode, 139 | lite=args.lite, 140 | forward_only=args.forward_only) 141 | assert np.array_equal(peaks["peak_id"].tolist(), peak_ids) 142 | peaks["peak_score"] = peak_pred_counts 143 | print() 144 | print(peaks.head()) 145 | print("Peak score table shape:", peaks.shape) 146 | print() 147 | peaks.to_csv(peak_scores_file, sep="\t", index=False) 148 | 149 | if len(shuf_variants_table) > 0 and not shuf_variants_done: 150 | shuf_logfc, shuf_jsd, \ 151 | shuf_allele1_quantile, shuf_allele2_quantile = get_variant_scores_with_peaks(shuf_allele1_pred_counts, 152 | shuf_allele2_pred_counts, 153 | shuf_allele1_pred_profiles, 154 | shuf_allele2_pred_profiles, 155 | np.array(peaks["peak_score"].tolist())) 156 | shuf_indel_idx, shuf_adjusted_jsd_list = adjust_indel_jsd(shuf_variants_table, 157 | shuf_allele1_pred_profiles, 158 | shuf_allele2_pred_profiles, 159 | shuf_jsd) 160 | shuf_has_indel_variants = (len(shuf_indel_idx) > 0) 161 | 162 | shuf_variants_table["logfc"] = shuf_logfc 163 | shuf_variants_table["abs_logfc"] = np.abs(shuf_logfc) 164 | if shuf_has_indel_variants: 165 | shuf_variants_table["jsd"] = shuf_adjusted_jsd_list 166 | else: 167 | shuf_variants_table["jsd"] = shuf_jsd 168 | assert np.array_equal(shuf_adjusted_jsd_list, shuf_jsd) 169 | shuf_variants_table['original_jsd'] = shuf_jsd 170 | shuf_variants_table["logfc_x_jsd"] = shuf_variants_table["logfc"] * shuf_variants_table["jsd"] 171 | shuf_variants_table["abs_logfc_x_jsd"] = shuf_variants_table["abs_logfc"] * shuf_variants_table["jsd"] 172 | 173 | shuf_variants_table["allele1_quantile"] = shuf_allele1_quantile 174 | shuf_variants_table["allele2_quantile"] = shuf_allele2_quantile 175 | shuf_variants_table["active_allele_quantile"] = shuf_variants_table[["allele1_quantile", "allele2_quantile"]].max(axis=1) 176 | shuf_variants_table["quantile_change"] = shuf_variants_table["allele2_quantile"] - shuf_variants_table["allele1_quantile"] 177 | shuf_variants_table["abs_quantile_change"] = np.abs(shuf_variants_table["quantile_change"]) 178 | shuf_variants_table["logfc_x_active_allele_quantile"] = shuf_variants_table["logfc"] * shuf_variants_table["active_allele_quantile"] 179 | shuf_variants_table["abs_logfc_x_active_allele_quantile"] = 
shuf_variants_table["abs_logfc"] * shuf_variants_table["active_allele_quantile"] 180 | shuf_variants_table["jsd_x_active_allele_quantile"] = shuf_variants_table["jsd"] * shuf_variants_table["active_allele_quantile"] 181 | shuf_variants_table["logfc_x_jsd_x_active_allele_quantile"] = shuf_variants_table["logfc_x_jsd"] * shuf_variants_table["active_allele_quantile"] 182 | shuf_variants_table["abs_logfc_x_jsd_x_active_allele_quantile"] = shuf_variants_table["abs_logfc_x_jsd"] * shuf_variants_table["active_allele_quantile"] 183 | 184 | assert shuf_variants_table["abs_logfc"].shape == shuf_logfc.shape 185 | assert shuf_variants_table["abs_logfc"].shape == shuf_jsd.shape 186 | assert shuf_variants_table["abs_logfc"].shape == shuf_variants_table["abs_logfc_x_jsd"].shape 187 | 188 | print() 189 | print(shuf_variants_table.head()) 190 | print("Shuffled score table shape:", shuf_variants_table.shape) 191 | print() 192 | shuf_variants_table.to_csv(shuf_scores_file, sep="\t", index=False) 193 | 194 | else: 195 | if len(shuf_variants_table) > 0 and not shuf_variants_done: 196 | shuf_logfc, shuf_jsd = get_variant_scores(shuf_allele1_pred_counts, 197 | shuf_allele2_pred_counts, 198 | shuf_allele1_pred_profiles, 199 | shuf_allele2_pred_profiles) 200 | 201 | shuf_indel_idx, shuf_adjusted_jsd_list = adjust_indel_jsd(shuf_variants_table, 202 | shuf_allele1_pred_profiles, 203 | shuf_allele2_pred_profiles, 204 | shuf_jsd) 205 | shuf_has_indel_variants = (len(shuf_indel_idx) > 0) 206 | 207 | shuf_variants_table["logfc"] = shuf_logfc 208 | shuf_variants_table["abs_logfc"] = np.abs(shuf_logfc) 209 | if shuf_has_indel_variants: 210 | shuf_variants_table["jsd"] = shuf_adjusted_jsd_list 211 | else: 212 | shuf_variants_table["jsd"] = shuf_jsd 213 | assert np.array_equal(shuf_adjusted_jsd_list, shuf_jsd) 214 | shuf_variants_table['original_jsd'] = shuf_jsd 215 | shuf_variants_table["logfc_x_jsd"] = shuf_variants_table["logfc"] * shuf_variants_table["jsd"] 216 | shuf_variants_table["abs_logfc_x_jsd"] = shuf_variants_table["abs_logfc"] * shuf_variants_table["jsd"] 217 | 218 | assert shuf_variants_table["abs_logfc"].shape == shuf_logfc.shape 219 | assert shuf_variants_table["abs_logfc"].shape == shuf_jsd.shape 220 | assert shuf_variants_table["abs_logfc"].shape == shuf_variants_table["abs_logfc_x_jsd"].shape 221 | 222 | print() 223 | print(shuf_variants_table.head()) 224 | print("Shuffled score table shape:", shuf_variants_table.shape) 225 | print() 226 | shuf_variants_table.to_csv(shuf_scores_file, sep="\t", index=False) 227 | 228 | todo_chroms = [x for x in variants_table.chr.unique()] 229 | 230 | for chrom in todo_chroms: 231 | print() 232 | print(chrom) 233 | print() 234 | 235 | chrom_variants_table = variants_table.loc[variants_table['chr'] == chrom].sort_values(by='pos').copy() 236 | chrom_variants_table.reset_index(drop=True, inplace=True) 237 | 238 | chrom_scores_done = False 239 | chrom_scores_file = '.'.join([args.out_prefix, str(chrom), "variant_scores.tsv"]) 240 | if os.path.isfile(chrom_scores_file): 241 | chrom_variants_table_loaded = pd.read_table(chrom_scores_file) 242 | if chrom_variants_table_loaded['variant_id'].tolist() == chrom_variants_table['variant_id'].tolist(): 243 | 244 | print("Loaded existing chrom scores file:", chrom_scores_file) 245 | 246 | chrom_scores_done = True 247 | 248 | if not chrom_scores_done: 249 | print(str(chrom) + " variants table shape:", chrom_variants_table.shape) 250 | print() 251 | 252 | if args.debug_mode: 253 | chrom_variants_table = chrom_variants_table.sample(10000, 
random_state=args.random_seed, ignore_index=True) 254 | print() 255 | print(chrom_variants_table.head()) 256 | print("Debug variants table shape:", chrom_variants_table.shape) 257 | print() 258 | 259 | # fetch model prediction for variants 260 | variant_ids, allele1_pred_counts, allele2_pred_counts, \ 261 | allele1_pred_profiles, allele2_pred_profiles = fetch_variant_predictions(model, 262 | chrom_variants_table, 263 | input_len, 264 | args.genome, 265 | args.batch_size, 266 | debug_mode=args.debug_mode, 267 | lite=args.lite, 268 | shuf=False, 269 | forward_only=args.forward_only) 270 | 271 | if args.peaks: 272 | logfc, jsd, \ 273 | allele1_quantile, allele2_quantile = get_variant_scores_with_peaks(allele1_pred_counts, 274 | allele2_pred_counts, 275 | allele1_pred_profiles, 276 | allele2_pred_profiles, 277 | np.array(peaks["peak_score"].tolist())) 278 | 279 | else: 280 | logfc, jsd = get_variant_scores(allele1_pred_counts, 281 | allele2_pred_counts, 282 | allele1_pred_profiles, 283 | allele2_pred_profiles) 284 | 285 | indel_idx, adjusted_jsd_list = adjust_indel_jsd(chrom_variants_table,allele1_pred_profiles,allele2_pred_profiles,jsd) 286 | has_indel_variants = (len(indel_idx) > 0) 287 | 288 | assert np.array_equal(chrom_variants_table["variant_id"].tolist(), variant_ids) 289 | chrom_variants_table["allele1_pred_counts"] = allele1_pred_counts 290 | chrom_variants_table["allele2_pred_counts"] = allele2_pred_counts 291 | chrom_variants_table["logfc"] = logfc 292 | chrom_variants_table["abs_logfc"] = np.abs(chrom_variants_table["logfc"]) 293 | if has_indel_variants: 294 | chrom_variants_table["jsd"] = adjusted_jsd_list 295 | else: 296 | chrom_variants_table["jsd"] = jsd 297 | assert np.array_equal(adjusted_jsd_list, jsd) 298 | chrom_variants_table["original_jsd"] = jsd 299 | chrom_variants_table["logfc_x_jsd"] = chrom_variants_table["logfc"] * chrom_variants_table["jsd"] 300 | chrom_variants_table["abs_logfc_x_jsd"] = chrom_variants_table["abs_logfc"] * chrom_variants_table["jsd"] 301 | 302 | if len(shuf_variants_table) > 0: 303 | chrom_variants_table["logfc.pval"] = get_pvals(chrom_variants_table["logfc"].tolist(), shuf_variants_table["logfc"], tail="both") 304 | chrom_variants_table["abs_logfc.pval"] = get_pvals(chrom_variants_table["abs_logfc"].tolist(), shuf_variants_table["abs_logfc"], tail="right") 305 | chrom_variants_table["jsd.pval"] = get_pvals(chrom_variants_table["jsd"].tolist(), shuf_variants_table["jsd"], tail="right") 306 | chrom_variants_table["logfc_x_jsd.pval"] = get_pvals(chrom_variants_table["logfc_x_jsd"].tolist(), shuf_variants_table["logfc_x_jsd"], tail="both") 307 | chrom_variants_table["abs_logfc_x_jsd.pval"] = get_pvals(chrom_variants_table["abs_logfc_x_jsd"].tolist(), shuf_variants_table["abs_logfc_x_jsd"], tail="right") 308 | if args.peaks: 309 | chrom_variants_table["allele1_quantile"] = allele1_quantile 310 | chrom_variants_table["allele2_quantile"] = allele2_quantile 311 | chrom_variants_table["active_allele_quantile"] = chrom_variants_table[["allele1_quantile", "allele2_quantile"]].max(axis=1) 312 | chrom_variants_table["quantile_change"] = chrom_variants_table["allele2_quantile"] - chrom_variants_table["allele1_quantile"] 313 | chrom_variants_table["abs_quantile_change"] = np.abs(chrom_variants_table["quantile_change"]) 314 | chrom_variants_table["logfc_x_active_allele_quantile"] = chrom_variants_table["logfc"] * chrom_variants_table["active_allele_quantile"] 315 | chrom_variants_table["abs_logfc_x_active_allele_quantile"] = chrom_variants_table["abs_logfc"] * 
chrom_variants_table["active_allele_quantile"] 316 | chrom_variants_table["jsd_x_active_allele_quantile"] = chrom_variants_table["jsd"] * chrom_variants_table["active_allele_quantile"] 317 | chrom_variants_table["logfc_x_jsd_x_active_allele_quantile"] = chrom_variants_table["logfc_x_jsd"] * chrom_variants_table["active_allele_quantile"] 318 | chrom_variants_table["abs_logfc_x_jsd_x_active_allele_quantile"] = chrom_variants_table["abs_logfc_x_jsd"] * chrom_variants_table["active_allele_quantile"] 319 | 320 | if len(shuf_variants_table) > 0: 321 | chrom_variants_table["active_allele_quantile.pval"] = get_pvals(chrom_variants_table["active_allele_quantile"].tolist(), 322 | shuf_variants_table["active_allele_quantile"], tail="right") 323 | chrom_variants_table['quantile_change.pval'] = get_pvals(chrom_variants_table["quantile_change"].tolist(), 324 | shuf_variants_table["quantile_change"], tail="both") 325 | chrom_variants_table["abs_quantile_change.pval"] = get_pvals(chrom_variants_table["abs_quantile_change"].tolist(), 326 | shuf_variants_table["abs_quantile_change"], tail="right") 327 | chrom_variants_table["logfc_x_active_allele_quantile.pval"] = get_pvals(chrom_variants_table["logfc_x_active_allele_quantile"].tolist(), 328 | shuf_variants_table["logfc_x_active_allele_quantile"], tail="both") 329 | chrom_variants_table["abs_logfc_x_active_allele_quantile.pval"] = get_pvals(chrom_variants_table["abs_logfc_x_active_allele_quantile"].tolist(), 330 | shuf_variants_table["abs_logfc_x_active_allele_quantile"], tail="right") 331 | chrom_variants_table["jsd_x_active_allele_quantile.pval"] = get_pvals(chrom_variants_table["jsd_x_active_allele_quantile"].tolist(), 332 | shuf_variants_table["jsd_x_active_allele_quantile"], tail="right") 333 | chrom_variants_table["logfc_x_jsd_x_active_allele_quantile.pval"] = get_pvals(chrom_variants_table["logfc_x_jsd_x_active_allele_quantile"].tolist(), 334 | shuf_variants_table["logfc_x_jsd_x_active_allele_quantile"], tail="both") 335 | chrom_variants_table["abs_logfc_x_jsd_x_active_allele_quantile.pval"] = get_pvals(chrom_variants_table["abs_logfc_x_jsd_x_active_allele_quantile"].tolist(), 336 | shuf_variants_table["abs_logfc_x_jsd_x_active_allele_quantile"], tail="right") 337 | 338 | if args.schema == "bed": 339 | chrom_variants_table['pos'] = chrom_variants_table['pos'] - 1 340 | 341 | # store predictions at variants 342 | if not args.no_hdf5: 343 | with h5py.File('.'.join([args.out_prefix, chrom, "variant_predictions.h5"]), 'w') as f: 344 | observed = f.create_group('observed') 345 | observed.create_dataset('allele1_pred_counts', data=allele1_pred_counts, compression='gzip', compression_opts=9) 346 | observed.create_dataset('allele2_pred_counts', data=allele2_pred_counts, compression='gzip', compression_opts=9) 347 | observed.create_dataset('allele1_pred_profiles', data=allele1_pred_profiles, compression='gzip', compression_opts=9) 348 | observed.create_dataset('allele2_pred_profiles', data=allele2_pred_profiles, compression='gzip', compression_opts=9) 349 | 350 | print() 351 | print(chrom_variants_table.head()) 352 | print("Output " + str(chrom) + " score table shape:", chrom_variants_table.shape) 353 | print() 354 | chrom_variants_table.to_csv(chrom_scores_file, sep="\t", index=False) 355 | 356 | # merge all per-chromosome predictions if requested 357 | if args.merge: 358 | print("Merging all per-chromosome predictions...") 359 | merged_dfs = [] 360 | 361 | for chrom in todo_chroms: 362 | chrom_scores_file = '.'.join([args.out_prefix, str(chrom), 
"variant_scores.tsv"]) 363 | if os.path.isfile(chrom_scores_file): 364 | chrom_df = pd.read_table(chrom_scores_file) 365 | merged_dfs.append(chrom_df) 366 | print(f"Added {chrom}: {len(chrom_df)} variants") 367 | 368 | # remove the per-chromosome scores file 369 | print(f"Removing {chrom_scores_file}...") 370 | os.remove(chrom_scores_file) 371 | else: 372 | print(f"Warning: {chrom_scores_file} not found, skipping") 373 | 374 | if merged_dfs: 375 | merged_df = pd.concat(merged_dfs, ignore_index=True) 376 | merged_file = '.'.join([args.out_prefix, "variant_scores.tsv"]) 377 | merged_df.to_csv(merged_file, sep="\t", index=False) 378 | print(f"Merged scores for {len(merged_dfs)} chromosomes into {merged_file}") 379 | print(f"Total variants: {len(merged_df)}") 380 | else: 381 | print("No chromosome files found to merge") 382 | 383 | print("DONE") 384 | print() 385 | 386 | 387 | if __name__ == "__main__": 388 | main() 389 | -------------------------------------------------------------------------------- /src/variant_scoring.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | import numpy as np 4 | import h5py 5 | from utils import argmanager 6 | from utils.helpers import * 7 | from utils.io import * 8 | 9 | 10 | def main(): 11 | args = argmanager.fetch_scoring_args() 12 | print(args) 13 | 14 | np.random.seed(args.random_seed) 15 | if args.forward_only: 16 | print("running variant scoring only for forward sequences") 17 | 18 | out_dir = os.path.sep.join(args.out_prefix.split(os.path.sep)[:-1]) 19 | if not os.path.exists(out_dir): 20 | raise OSError("Output directory does not exist") 21 | 22 | # load the model and variants 23 | model = load_model_wrapper(args.model) 24 | variants_table = load_variant_table(args.list, args.schema) 25 | variants_table = variants_table.fillna('-') 26 | 27 | chrom_sizes = pd.read_csv(args.chrom_sizes, header=None, sep='\t', names=['chrom', 'size']) 28 | chrom_sizes_dict = chrom_sizes.set_index('chrom')['size'].to_dict() 29 | 30 | print("Original variants table shape:", variants_table.shape) 31 | 32 | if args.chrom: 33 | variants_table = variants_table.loc[variants_table['chr'] == args.chrom] 34 | print("Chromosome variants table shape:", variants_table.shape) 35 | 36 | # infer input length 37 | if args.lite: 38 | input_len = model.input_shape[0][1] 39 | else: 40 | input_len = model.input_shape[1] 41 | 42 | print("Input length inferred from the model:", input_len) 43 | 44 | variants_table = variants_table.loc[variants_table.apply(lambda x: get_valid_variants(x.chr, x.pos, x.allele1, x.allele2, input_len, chrom_sizes_dict), axis=1)] 45 | variants_table.reset_index(drop=True, inplace=True) 46 | 47 | print("Final variants table shape:", variants_table.shape) 48 | 49 | if args.shuffled_scores: 50 | shuf_variants_table = pd.read_table(args.shuffled_scores) 51 | print("Shuffled variants table shape:", shuf_variants_table.shape) 52 | shuf_scores_file = args.shuffled_scores 53 | 54 | else: 55 | shuf_variants_table = create_shuffle_table(variants_table, args.random_seed, args.total_shuf, args.num_shuf) 56 | print("Shuffled variants table shape:", shuf_variants_table.shape) 57 | shuf_scores_file = '.'.join([args.out_prefix, "variant_scores.shuffled.tsv"]) 58 | 59 | peak_scores_file = '.'.join([args.out_prefix, "peak_scores.tsv"]) 60 | 61 | if len(shuf_variants_table) > 0: 62 | if args.debug_mode: 63 | shuf_variants_table = shuf_variants_table.sample(10000, random_state=args.random_seed, ignore_index=True) 
64 | print() 65 | print(shuf_variants_table.head()) 66 | print("Debug shuffled variants table shape:", shuf_variants_table.shape) 67 | print() 68 | 69 | shuf_variants_done = False 70 | if os.path.isfile(shuf_scores_file): 71 | shuf_variants_table_loaded = pd.read_table(shuf_scores_file) 72 | if shuf_variants_table_loaded['variant_id'].tolist() == shuf_variants_table['variant_id'].tolist(): 73 | shuf_variants_table = shuf_variants_table_loaded.copy() 74 | shuf_variants_done = True 75 | 76 | if not shuf_variants_done: 77 | shuf_variant_ids, shuf_allele1_pred_counts, shuf_allele2_pred_counts, \ 78 | shuf_allele1_pred_profiles, shuf_allele2_pred_profiles = fetch_variant_predictions(model, 79 | shuf_variants_table, 80 | input_len, 81 | args.genome, 82 | args.batch_size, 83 | debug_mode=args.debug_mode, 84 | lite=args.lite, 85 | shuf=True, 86 | forward_only=args.forward_only) 87 | assert np.array_equal(shuf_variants_table["variant_id"].tolist(), shuf_variant_ids) 88 | shuf_variants_table["allele1_pred_counts"] = shuf_allele1_pred_counts 89 | shuf_variants_table["allele2_pred_counts"] = shuf_allele2_pred_counts 90 | 91 | if args.peaks: 92 | if args.peak_chrom_sizes == None: 93 | args.peak_chrom_sizes = args.chrom_sizes 94 | if args.peak_genome == None: 95 | args.peak_genome = args.genome 96 | 97 | peak_chrom_sizes = pd.read_csv(args.peak_chrom_sizes, header=None, sep='\t', names=['chrom', 'size']) 98 | peak_chrom_sizes_dict = peak_chrom_sizes.set_index('chrom')['size'].to_dict() 99 | 100 | peaks = pd.read_csv(args.peaks, header=None, sep='\t') 101 | peaks = add_missing_columns_to_peaks_df(peaks, schema='narrowpeak') 102 | peaks['peak_id'] = peaks['chr'] + ':' + peaks['start'].astype(str) + '-' + peaks['end'].astype(str) 103 | 104 | print("Original peak table shape:", peaks.shape) 105 | 106 | peaks.sort_values(by=['chr', 'start', 'end', 'summit', 'rank'], ascending=[True, True, True, True, False], inplace=True) 107 | peaks.drop_duplicates(subset=['chr', 'start', 'end', 'summit'], inplace=True) 108 | peaks = peaks.loc[peaks.apply(lambda x: get_valid_peaks(x.chr, x.start, x.summit, input_len, peak_chrom_sizes_dict), axis=1)] 109 | peaks.reset_index(drop=True, inplace=True) 110 | 111 | print("De-duplicated peak table shape:", peaks.shape) 112 | 113 | if args.debug_mode: 114 | peaks = peaks.sample(10000, random_state=args.random_seed, ignore_index=True) 115 | print() 116 | print(peaks.head()) 117 | print("Debug peak table shape:", peaks.shape) 118 | print() 119 | 120 | if args.max_peaks: 121 | if len(peaks) > args.max_peaks: 122 | peaks = peaks.sample(args.max_peaks, random_state=args.random_seed, ignore_index=True) 123 | print("Subsampled peak table shape:", peaks.shape) 124 | 125 | peak_scores_done = False 126 | if os.path.isfile(peak_scores_file): 127 | peaks_loaded = pd.read_table(peak_scores_file) 128 | if peaks_loaded['peak_id'].tolist() == peaks['peak_id'].tolist(): 129 | peaks = peaks_loaded.copy() 130 | peak_scores_done = True 131 | 132 | if not peak_scores_done: 133 | peak_ids, peak_pred_counts, peak_pred_profiles = fetch_peak_predictions(model, 134 | peaks, 135 | input_len, 136 | args.peak_genome, 137 | args.batch_size, 138 | debug_mode=args.debug_mode, 139 | lite=args.lite, 140 | forward_only=args.forward_only) 141 | assert np.array_equal(peaks["peak_id"].tolist(), peak_ids) 142 | peaks["peak_score"] = peak_pred_counts 143 | print() 144 | print(peaks.head()) 145 | print("Peak score table shape:", peaks.shape) 146 | print() 147 | peaks.to_csv(peak_scores_file, sep="\t", index=False) 148 | 
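        # When peaks are provided, each allele's predicted counts are additionally
        # expressed as a quantile of the predicted peak score distribution (via
        # get_variant_scores_with_peaks below). As a rough illustration of the
        # idea, assuming `peak_scores` and `allele_counts` are 1-D numpy arrays
        # (a sketch, not the exact implementation):
        #
        #   from scipy.stats import percentileofscore
        #   quantiles = np.array([percentileofscore(peak_scores, c) / 100
        #                         for c in allele_counts])
        #
        # An allele predicted to be as accessible as a strong peak gets a
        # quantile near 1; one scoring below most peaks gets a quantile near 0.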
149 | if len(shuf_variants_table) > 0 and not shuf_variants_done: 150 | shuf_logfc, shuf_jsd, \ 151 | shuf_allele1_quantile, shuf_allele2_quantile = get_variant_scores_with_peaks(shuf_allele1_pred_counts, 152 | shuf_allele2_pred_counts, 153 | shuf_allele1_pred_profiles, 154 | shuf_allele2_pred_profiles, 155 | np.array(peaks["peak_score"].tolist())) 156 | shuf_indel_idx, shuf_adjusted_jsd_list = adjust_indel_jsd(shuf_variants_table, 157 | shuf_allele1_pred_profiles, 158 | shuf_allele2_pred_profiles, 159 | shuf_jsd) 160 | shuf_has_indel_variants = (len(shuf_indel_idx) > 0) 161 | 162 | shuf_variants_table["logfc"] = shuf_logfc 163 | shuf_variants_table["abs_logfc"] = np.abs(shuf_logfc) 164 | if shuf_has_indel_variants: 165 | shuf_variants_table["jsd"] = shuf_adjusted_jsd_list 166 | else: 167 | shuf_variants_table["jsd"] = shuf_jsd 168 | assert np.array_equal(shuf_adjusted_jsd_list, shuf_jsd) 169 | shuf_variants_table['original_jsd'] = shuf_jsd 170 | shuf_variants_table["logfc_x_jsd"] = shuf_variants_table["logfc"] * shuf_variants_table["jsd"] 171 | shuf_variants_table["abs_logfc_x_jsd"] = shuf_variants_table["abs_logfc"] * shuf_variants_table["jsd"] 172 | 173 | shuf_variants_table["allele1_quantile"] = shuf_allele1_quantile 174 | shuf_variants_table["allele2_quantile"] = shuf_allele2_quantile 175 | shuf_variants_table["active_allele_quantile"] = shuf_variants_table[["allele1_quantile", "allele2_quantile"]].max(axis=1) 176 | shuf_variants_table["quantile_change"] = shuf_variants_table["allele2_quantile"] - shuf_variants_table["allele1_quantile"] 177 | shuf_variants_table["abs_quantile_change"] = np.abs(shuf_variants_table["quantile_change"]) 178 | shuf_variants_table["logfc_x_active_allele_quantile"] = shuf_variants_table["logfc"] * shuf_variants_table["active_allele_quantile"] 179 | shuf_variants_table["abs_logfc_x_active_allele_quantile"] = shuf_variants_table["abs_logfc"] * shuf_variants_table["active_allele_quantile"] 180 | shuf_variants_table["jsd_x_active_allele_quantile"] = shuf_variants_table["jsd"] * shuf_variants_table["active_allele_quantile"] 181 | shuf_variants_table["logfc_x_jsd_x_active_allele_quantile"] = shuf_variants_table["logfc_x_jsd"] * shuf_variants_table["active_allele_quantile"] 182 | shuf_variants_table["abs_logfc_x_jsd_x_active_allele_quantile"] = shuf_variants_table["abs_logfc_x_jsd"] * shuf_variants_table["active_allele_quantile"] 183 | 184 | assert shuf_variants_table["abs_logfc"].shape == shuf_logfc.shape 185 | assert shuf_variants_table["abs_logfc"].shape == shuf_jsd.shape 186 | assert shuf_variants_table["abs_logfc"].shape == shuf_variants_table["abs_logfc_x_jsd"].shape 187 | 188 | print() 189 | print(shuf_variants_table.head()) 190 | print("Shuffled score table shape:", shuf_variants_table.shape) 191 | print() 192 | shuf_variants_table.to_csv(shuf_scores_file, sep="\t", index=False) 193 | 194 | else: 195 | if len(shuf_variants_table) > 0 and not shuf_variants_done: 196 | shuf_logfc, shuf_jsd = get_variant_scores(shuf_allele1_pred_counts, 197 | shuf_allele2_pred_counts, 198 | shuf_allele1_pred_profiles, 199 | shuf_allele2_pred_profiles) 200 | 201 | shuf_indel_idx, shuf_adjusted_jsd_list = adjust_indel_jsd(shuf_variants_table, 202 | shuf_allele1_pred_profiles, 203 | shuf_allele2_pred_profiles, 204 | shuf_jsd) 205 | shuf_has_indel_variants = (len(shuf_indel_idx) > 0) 206 | 207 | shuf_variants_table["logfc"] = shuf_logfc 208 | shuf_variants_table["abs_logfc"] = np.abs(shuf_logfc) 209 | if shuf_has_indel_variants: 210 | shuf_variants_table["jsd"] = 
shuf_adjusted_jsd_list 211 | else: 212 | shuf_variants_table["jsd"] = shuf_jsd 213 | assert np.array_equal(shuf_adjusted_jsd_list, shuf_jsd) 214 | shuf_variants_table['original_jsd'] = shuf_jsd 215 | shuf_variants_table["logfc_x_jsd"] = shuf_variants_table["logfc"] * shuf_variants_table["jsd"] 216 | shuf_variants_table["abs_logfc_x_jsd"] = shuf_variants_table["abs_logfc"] * shuf_variants_table["jsd"] 217 | 218 | assert shuf_variants_table["abs_logfc"].shape == shuf_logfc.shape 219 | assert shuf_variants_table["abs_logfc"].shape == shuf_jsd.shape 220 | assert shuf_variants_table["abs_logfc"].shape == shuf_variants_table["abs_logfc_x_jsd"].shape 221 | 222 | print() 223 | print(shuf_variants_table.head()) 224 | print("Shuffled score table shape:", shuf_variants_table.shape) 225 | print() 226 | shuf_variants_table.to_csv(shuf_scores_file, sep="\t", index=False) 227 | 228 | if args.debug_mode: 229 | variants_table = variants_table.sample(10000, random_state=args.random_seed, ignore_index=True) 230 | print() 231 | print(variants_table.head()) 232 | print("Debug variants table shape:", variants_table.shape) 233 | print() 234 | 235 | # fetch model prediction for variants 236 | variant_ids, allele1_pred_counts, allele2_pred_counts, \ 237 | allele1_pred_profiles, allele2_pred_profiles = fetch_variant_predictions(model, 238 | variants_table, 239 | input_len, 240 | args.genome, 241 | args.batch_size, 242 | debug_mode=args.debug_mode, 243 | lite=args.lite, 244 | shuf=False, 245 | forward_only=args.forward_only) 246 | 247 | if args.peaks: 248 | logfc, jsd, \ 249 | allele1_quantile, allele2_quantile = get_variant_scores_with_peaks(allele1_pred_counts, 250 | allele2_pred_counts, 251 | allele1_pred_profiles, 252 | allele2_pred_profiles, 253 | np.array(peaks["peak_score"].tolist())) 254 | 255 | else: 256 | logfc, jsd = get_variant_scores(allele1_pred_counts, 257 | allele2_pred_counts, 258 | allele1_pred_profiles, 259 | allele2_pred_profiles) 260 | 261 | indel_idx, adjusted_jsd_list = adjust_indel_jsd(variants_table,allele1_pred_profiles,allele2_pred_profiles,jsd) 262 | has_indel_variants = (len(indel_idx) > 0) 263 | 264 | assert np.array_equal(variants_table["variant_id"].tolist(), variant_ids) 265 | variants_table["allele1_pred_counts"] = allele1_pred_counts 266 | variants_table["allele2_pred_counts"] = allele2_pred_counts 267 | variants_table["logfc"] = logfc 268 | variants_table["abs_logfc"] = np.abs(variants_table["logfc"]) 269 | if has_indel_variants: 270 | variants_table["jsd"] = adjusted_jsd_list 271 | else: 272 | variants_table["jsd"] = jsd 273 | assert np.array_equal(adjusted_jsd_list, jsd) 274 | variants_table["original_jsd"] = jsd 275 | variants_table["logfc_x_jsd"] = variants_table["logfc"] * variants_table["jsd"] 276 | variants_table["abs_logfc_x_jsd"] = variants_table["abs_logfc"] * variants_table["jsd"] 277 | 278 | if len(shuf_variants_table) > 0: 279 | variants_table["logfc.pval"] = get_pvals(variants_table["logfc"].tolist(), shuf_variants_table["logfc"], tail="both") 280 | variants_table["abs_logfc.pval"] = get_pvals(variants_table["abs_logfc"].tolist(), shuf_variants_table["abs_logfc"], tail="right") 281 | variants_table["jsd.pval"] = get_pvals(variants_table["jsd"].tolist(), shuf_variants_table["jsd"], tail="right") 282 | variants_table["logfc_x_jsd.pval"] = get_pvals(variants_table["logfc_x_jsd"].tolist(), shuf_variants_table["logfc_x_jsd"], tail="both") 283 | variants_table["abs_logfc_x_jsd.pval"] = get_pvals(variants_table["abs_logfc_x_jsd"].tolist(), 
shuf_variants_table["abs_logfc_x_jsd"], tail="right") 284 | if args.peaks: 285 | variants_table["allele1_quantile"] = allele1_quantile 286 | variants_table["allele2_quantile"] = allele2_quantile 287 | variants_table["active_allele_quantile"] = variants_table[["allele1_quantile", "allele2_quantile"]].max(axis=1) 288 | variants_table["quantile_change"] = variants_table["allele2_quantile"] - variants_table["allele1_quantile"] 289 | variants_table["abs_quantile_change"] = np.abs(variants_table["quantile_change"]) 290 | variants_table["logfc_x_active_allele_quantile"] = variants_table["logfc"] * variants_table["active_allele_quantile"] 291 | variants_table["abs_logfc_x_active_allele_quantile"] = variants_table["abs_logfc"] * variants_table["active_allele_quantile"] 292 | variants_table["jsd_x_active_allele_quantile"] = variants_table["jsd"] * variants_table["active_allele_quantile"] 293 | variants_table["logfc_x_jsd_x_active_allele_quantile"] = variants_table["logfc_x_jsd"] * variants_table["active_allele_quantile"] 294 | variants_table["abs_logfc_x_jsd_x_active_allele_quantile"] = variants_table["abs_logfc_x_jsd"] * variants_table["active_allele_quantile"] 295 | 296 | if len(shuf_variants_table) > 0: 297 | variants_table["active_allele_quantile.pval"] = get_pvals(variants_table["active_allele_quantile"].tolist(), 298 | shuf_variants_table["active_allele_quantile"], tail="right") 299 | variants_table['quantile_change.pval'] = get_pvals(variants_table["quantile_change"].tolist(), 300 | shuf_variants_table["quantile_change"], tail="both") 301 | variants_table["abs_quantile_change.pval"] = get_pvals(variants_table["abs_quantile_change"].tolist(), 302 | shuf_variants_table["abs_quantile_change"], tail="right") 303 | variants_table["logfc_x_active_allele_quantile.pval"] = get_pvals(variants_table["logfc_x_active_allele_quantile"].tolist(), 304 | shuf_variants_table["logfc_x_active_allele_quantile"], tail="both") 305 | variants_table["abs_logfc_x_active_allele_quantile.pval"] = get_pvals(variants_table["abs_logfc_x_active_allele_quantile"].tolist(), 306 | shuf_variants_table["abs_logfc_x_active_allele_quantile"], tail="right") 307 | variants_table["jsd_x_active_allele_quantile.pval"] = get_pvals(variants_table["jsd_x_active_allele_quantile"].tolist(), 308 | shuf_variants_table["jsd_x_active_allele_quantile"], tail="right") 309 | variants_table["logfc_x_jsd_x_active_allele_quantile.pval"] = get_pvals(variants_table["logfc_x_jsd_x_active_allele_quantile"].tolist(), 310 | shuf_variants_table["logfc_x_jsd_x_active_allele_quantile"], tail="both") 311 | variants_table["abs_logfc_x_jsd_x_active_allele_quantile.pval"] = get_pvals(variants_table["abs_logfc_x_jsd_x_active_allele_quantile"].tolist(), 312 | shuf_variants_table["abs_logfc_x_jsd_x_active_allele_quantile"], tail="right") 313 | 314 | if args.schema == "bed": 315 | variants_table['pos'] = variants_table['pos'] - 1 316 | 317 | # store predictions at variants 318 | if not args.no_hdf5: 319 | with h5py.File('.'.join([args.out_prefix, "variant_predictions.h5"]), 'w') as f: 320 | observed = f.create_group('observed') 321 | observed.create_dataset('allele1_pred_counts', data=allele1_pred_counts, compression='gzip', compression_opts=9) 322 | observed.create_dataset('allele2_pred_counts', data=allele2_pred_counts, compression='gzip', compression_opts=9) 323 | observed.create_dataset('allele1_pred_profiles', data=allele1_pred_profiles, compression='gzip', compression_opts=9) 324 | observed.create_dataset('allele2_pred_profiles', data=allele2_pred_profiles, 
compression='gzip', compression_opts=9) 325 | 326 | print() 327 | print(variants_table.head()) 328 | print("Output score table shape:", variants_table.shape) 329 | print() 330 | variants_table.to_csv('.'.join([args.out_prefix, "variant_scores.tsv"]), sep="\t", index=False) 331 | 332 | print("DONE") 333 | print() 334 | 335 | 336 | if __name__ == "__main__": 337 | main() 338 | -------------------------------------------------------------------------------- /src/variant_shap.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import os 4 | import argparse 5 | import scipy.stats 6 | from scipy.spatial.distance import jensenshannon 7 | from tensorflow.keras.utils import get_custom_objects 8 | from tensorflow.keras.models import load_model 9 | import tensorflow as tf 10 | import h5py 11 | import math 12 | from generators.variant_generator import VariantGenerator 13 | from generators.peak_generator import PeakGenerator 14 | from utils import argmanager, losses 15 | from utils.helpers import * 16 | from utils.io import * 17 | import shap 18 | from utils.shap_utils import * 19 | import deepdish as dd 20 | tf.compat.v1.disable_v2_behavior() 21 | 22 | 23 | def main(): 24 | args = argmanager.fetch_shap_args() 25 | print(args) 26 | 27 | out_dir = os.path.sep.join(args.out_prefix.split(os.path.sep)[:-1]) 28 | print() 29 | print('out_dir:', out_dir) 30 | print() 31 | if not os.path.exists(out_dir): 32 | raise OSError("Output directory does not exist") 33 | 34 | model = load_model_wrapper(args.model) 35 | variants_table = load_variant_table(args.list, args.schema) 36 | variants_table = variants_table.fillna('-') 37 | 38 | chrom_sizes = pd.read_csv(args.chrom_sizes, header=None, sep='\t', names=['chrom', 'size']) 39 | chrom_sizes_dict = chrom_sizes.set_index('chrom')['size'].to_dict() 40 | 41 | if args.debug_mode: 42 | variants_table = variants_table.sample(10) 43 | print(variants_table.head()) 44 | 45 | # infer input length 46 | if args.lite: 47 | input_len = model.input_shape[0][1] 48 | else: 49 | input_len = model.input_shape[1] 50 | print("input length inferred from the model: ", input_len) 51 | 52 | print(variants_table.shape) 53 | variants_table = variants_table.loc[variants_table.apply(lambda x: get_valid_variants(x.chr, x.pos, x.allele1, x.allele2, input_len, chrom_sizes_dict), axis=1)] 54 | variants_table.reset_index(drop=True, inplace=True) 55 | print(variants_table.shape) 56 | 57 | for shap_type in args.shap_type: 58 | # fetch model prediction for variants 59 | batch_size=args.batch_size 60 | ### set the batch size to the length of variant table in case variant table is small to avoid error 61 | batch_size=min(batch_size,len(variants_table)) 62 | # output_file=h5py.File(''.join([args.out_prefix, ".variant_shap.%s.h5"%shap_type]), 'w') 63 | # observed = output_file.create_group('observed') 64 | # allele1_write = observed.create_dataset('allele1_shap', (len(variants_table),2114,4), chunks=(batch_size,2114,4), dtype=np.float16, compression='gzip', compression_opts=9) 65 | # allele2_write = observed.create_dataset('allele2_shap', (len(variants_table),2114,4), chunks=(batch_size,2114,4), dtype=np.float16, compression='gzip', compression_opts=9) 66 | # variant_ids_write = observed.create_dataset('variant_ids', (len(variants_table),), chunks=(batch_size,), dtype='S100', compression='gzip', compression_opts=9) 67 | 68 | allele1_seqs = [] 69 | allele2_seqs = [] 70 | allele1_scores = [] 71 | allele2_scores = [] 72 | 
variant_ids = [] 73 | 74 | num_batches = len(variants_table) // batch_size  # full batches; the remainder is scored below 75 | for i in range(num_batches): 76 | sub_table = variants_table[i*batch_size:(i+1)*batch_size] 77 | var_ids, allele1_inputs, allele2_inputs, \ 78 | allele1_shap, allele2_shap = fetch_shap(model, 79 | sub_table, 80 | input_len, 81 | args.genome, 82 | args.batch_size, 83 | debug_mode=args.debug_mode, 84 | lite=args.lite, 85 | bias=None, 86 | shuf=False, 87 | shap_type=shap_type) 88 | 89 | # allele1_write[i*batch_size:(i+1)*batch_size] = allele1_shap 90 | # allele2_write[i*batch_size:(i+1)*batch_size] = allele2_shap 91 | # variant_ids_write[i*batch_size:(i+1)*batch_size] = [s.encode("utf-8") for s in var_ids] 92 | 93 | if len(variant_ids) == 0: 94 | allele1_seqs = allele1_inputs 95 | allele2_seqs = allele2_inputs 96 | allele1_scores = allele1_shap 97 | allele2_scores = allele2_shap 98 | variant_ids = var_ids 99 | else: 100 | allele1_seqs = np.concatenate((allele1_seqs, allele1_inputs)) 101 | allele2_seqs = np.concatenate((allele2_seqs, allele2_inputs)) 102 | allele1_scores = np.concatenate((allele1_scores, allele1_shap)) 103 | allele2_scores = np.concatenate((allele2_scores, allele2_shap)) 104 | variant_ids = np.concatenate((variant_ids, var_ids)) 105 | 106 | if len(variants_table) % batch_size != 0:  # score the final partial batch 107 | sub_table = variants_table[num_batches*batch_size:len(variants_table)] 108 | var_ids, allele1_inputs, allele2_inputs, \ 109 | allele1_shap, allele2_shap = fetch_shap(model, 110 | sub_table, 111 | input_len, 112 | args.genome, 113 | args.batch_size, 114 | debug_mode=args.debug_mode, 115 | lite=args.lite, 116 | bias=None, 117 | shuf=False, 118 | shap_type=shap_type) 119 | 120 | # allele1_write[num_batches*batch_size:len(variants_table)] = allele1_shap 121 | # allele2_write[num_batches*batch_size:len(variants_table)] = allele2_shap 122 | # variant_ids_write[num_batches*batch_size:len(variants_table)] = [s.encode("utf-8") for s in var_ids] 123 | 124 | if len(variant_ids) == 0: 125 | allele1_seqs = allele1_inputs 126 | allele2_seqs = allele2_inputs 127 | allele1_scores = allele1_shap 128 | allele2_scores = allele2_shap 129 | variant_ids = var_ids 130 | else: 131 | allele1_seqs = np.concatenate((allele1_seqs, allele1_inputs)) 132 | allele2_seqs = np.concatenate((allele2_seqs, allele2_inputs)) 133 | allele1_scores = np.concatenate((allele1_scores, allele1_shap)) 134 | allele2_scores = np.concatenate((allele2_scores, allele2_shap)) 135 | variant_ids = np.concatenate((variant_ids, var_ids)) 136 | 137 | # # store shap at variants 138 | # with h5py.File(''.join([args.out_prefix, ".variant_shap.%s.h5"%shap_type]), 'w') as f: 139 | # observed = f.create_group('observed') 140 | # observed.create_dataset('allele1_shap', data=allele1_shap, compression='gzip', compression_opts=9) 141 | # observed.create_dataset('allele2_shap', data=allele2_shap, compression='gzip', compression_opts=9) 142 | 143 | assert(allele1_seqs.shape == allele1_scores.shape) 144 | assert(allele2_seqs.shape == allele2_scores.shape) 145 | assert(allele1_seqs.shape == allele2_seqs.shape) 146 | assert(allele1_scores.shape == allele2_scores.shape) 147 | assert(allele1_seqs.shape[2] == 4) 148 | assert(len(allele1_seqs) == len(variant_ids)) 149 | 150 | shap_dict = { 151 | 'raw': {'seq': np.concatenate((np.transpose(allele1_seqs, (0, 2, 1)).astype(np.int8), 152 | np.transpose(allele2_seqs, (0, 2, 1)).astype(np.int8)))}, 153 | 'shap': {'seq': np.concatenate((np.transpose(allele1_scores, (0, 2, 1)).astype(np.float16), 154 | np.transpose(allele2_scores, (0, 2, 1)).astype(np.float16)))}, 155 | 'projected_shap': {'seq': np.concatenate((np.transpose(allele1_seqs * allele1_scores, (0, 2, 1)).astype(np.float16), 156 | np.transpose(allele2_seqs * allele2_scores, (0, 2, 1)).astype(np.float16)))}, 157 | 'variant_ids': np.concatenate((np.array(variant_ids), np.array(variant_ids))), 158 | 'alleles': np.concatenate((np.array([0] * len(variant_ids)), 159 | np.array([1] * len(variant_ids))))} 160 | 161 | dd.io.save(''.join([args.out_prefix, ".variant_shap.%s.h5"%shap_type]), 162 | shap_dict, 163 | compression='blosc') 164 | 165 | print("DONE") 166 | 167 | 168 | if __name__ == "__main__": 169 | main() 170 | -------------------------------------------------------------------------------- /src/variant_summary_across_folds.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import os 4 | from utils.argmanager import * 5 | from utils.io import * 6 | 7 | 8 | def geo_mean_overflow(iterable, axis=0): 9 | return np.exp(np.log(iterable).mean(axis=axis))  # geometric mean in log space to avoid overflow 10 | 11 | 12 | def main(): 13 | args = fetch_variant_summary_args() 14 | print(args) 15 | variant_score_dir = args.score_dir 16 | variant_table_list = args.score_list 17 | output_prefix = args.out_prefix 18 | 19 | score_dict = {} 20 | for i in range(len(variant_table_list)): 21 | variant_score_file = os.path.join(variant_score_dir, variant_table_list[i]) 22 | 23 | if not os.path.isfile(variant_score_file): 24 | raise FileNotFoundError(f"Variant score file not found: {variant_score_file}") 25 | 26 | var_score = pd.read_table(variant_score_file) 27 | score_dict[i] = var_score 28 | 29 | variant_scores = score_dict[0][get_variant_schema(args.schema)].copy() 30 | for i in score_dict: 31 | assert score_dict[i]['chr'].tolist() == variant_scores['chr'].tolist() 32 | assert score_dict[i]['pos'].tolist() == variant_scores['pos'].tolist() 33 | assert score_dict[i]['allele1'].tolist() == variant_scores['allele1'].tolist() 34 | assert score_dict[i]['allele2'].tolist() == variant_scores['allele2'].tolist() 35 | assert score_dict[i]['variant_id'].tolist() == variant_scores['variant_id'].tolist() 36 | 37 | for score in ["logfc", "abs_logfc", "jsd", "logfc_x_jsd", "abs_logfc_x_jsd", "active_allele_quantile", 38 | "logfc_x_active_allele_quantile", "abs_logfc_x_active_allele_quantile", "jsd_x_active_allele_quantile", 39 | "logfc_x_jsd_x_active_allele_quantile", "abs_logfc_x_jsd_x_active_allele_quantile", 40 | "quantile_change", "abs_quantile_change"]: 41 | if score in score_dict[0]: 42 | variant_scores.loc[:, (score + '.mean')] = np.mean(np.array([score_dict[fold][score].tolist() 43 | for fold in score_dict]), axis=0) 44 | if score + '.pval' in score_dict[0]: 45 | variant_scores.loc[:, (score + '.mean' + '.pval')] = geo_mean_overflow([score_dict[fold][score + '.pval'].values for fold in score_dict]) 46 | elif score + '_pval' in score_dict[0]: 47 | variant_scores.loc[:, (score + '.mean' + '.pval')] = geo_mean_overflow([score_dict[fold][score + '_pval'].values for fold in score_dict]) 48 | 49 | print() 50 | print(variant_scores.head()) 51 | print("Summary score table shape:", variant_scores.shape) 52 | print() 53 | 54 | out_file = output_prefix + ".mean.variant_scores.tsv" 55 | variant_scores.to_csv(out_file, 56 | sep="\t", 57 | index=False) 58 | 59 | print("DONE") 60 | print() 61 | 62 | 63 | if __name__ == "__main__": 64 | main() 65 | -------------------------------------------------------------------------------- /tests/archive/test.annotations.sh:
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python ../src/variant_annotation.py -p /oak/stanford/groups/akundaje/projects/chromatin-atlas-2022/ATAC/ENCSR611BQR/preprocessing/downloads/peaks.bed.gz -ge /oak/stanford/groups/akundaje/soumyak/refs/hg38/hg38.tss.bed -o annotations/test -sc chrombpnet -l /oak/stanford/groups/akundaje/airanman/nautilus-sync/gregor-luria/pvc/outputs/gregor-luria/variant_summary/ENCSR611BQR/.mean.variant_scores.tsv 4 | 5 | -------------------------------------------------------------------------------- /tests/archive/test.forward_only.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -u 5 | set -o pipefail 6 | set -x 7 | 8 | #rm -rf output 9 | mkdir -p output/ 10 | 11 | python -u ../src/variant_scoring.py \ 12 | -l /oak/stanford/groups/akundaje/airanman/test_data/variant-scorer/shared/encode_variants.subset.tsv \ 13 | -g /oak/stanford/groups/akundaje/airanman/test_data/variant-scorer/shared/GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta \ 14 | -s /oak/stanford/groups/akundaje/airanman/test_data/variant-scorer/shared/GRCh38_EBV.chrom.sizes.tsv \ 15 | -m /oak/stanford/groups/akundaje/airanman/test_data/variant-scorer/ENCSR999NKW/fold_0/chrombpnet_wo_bias.h5 \ 16 | -p /oak/stanford/groups/akundaje/airanman/test_data/variant-scorer/ENCSR999NKW/peaks.subset.bed.gz \ 17 | -fo \ 18 | -o output/forward_only \ 19 | -t 20 \ 20 | -sc chrombpnet 21 | 22 | -------------------------------------------------------------------------------- /tests/archive/test.per_chrom.forward_only.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -u 5 | set -o pipefail 6 | set -x 7 | 8 | #rm -rf output 9 | mkdir -p output/ 10 | 11 | python -u ../src/variant_scoring.per_chrom.py \ 12 | -l /oak/stanford/groups/akundaje/airanman/test_data/variant-scorer/shared/encode_variants.subset.tsv \ 13 | -g /oak/stanford/groups/akundaje/airanman/test_data/variant-scorer/shared/GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta \ 14 | -s /oak/stanford/groups/akundaje/airanman/test_data/variant-scorer/shared/GRCh38_EBV.chrom.sizes.tsv \ 15 | -m /oak/stanford/groups/akundaje/airanman/test_data/variant-scorer/ENCSR999NKW/fold_0/chrombpnet_wo_bias.h5 \ 16 | -p /oak/stanford/groups/akundaje/airanman/test_data/variant-scorer/ENCSR999NKW/peaks.subset.bed.gz \ 17 | -fo \ 18 | -o output/forward_only \ 19 | -t 20 \ 20 | -sc chrombpnet 21 | 22 | -------------------------------------------------------------------------------- /tests/archive/test.per_chrom.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -u 5 | set -o pipefail 6 | set -x 7 | 8 | #rm -rf output 9 | mkdir -p output/ 10 | 11 | python -u ../src/variant_scoring.per_chrom.py \ 12 | -l /oak/stanford/groups/akundaje/airanman/test_data/variant-scorer/shared/encode_variants.subset.tsv \ 13 | -g /oak/stanford/groups/akundaje/airanman/test_data/variant-scorer/shared/GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta \ 14 | -s /oak/stanford/groups/akundaje/airanman/test_data/variant-scorer/shared/GRCh38_EBV.chrom.sizes.tsv \ 15 | -m /oak/stanford/groups/akundaje/airanman/test_data/variant-scorer/ENCSR999NKW/fold_0/chrombpnet_wo_bias.h5 \ 16 | -p /oak/stanford/groups/akundaje/airanman/test_data/variant-scorer/ENCSR999NKW/peaks.subset.bed.gz \ 17 | -o output/test \ 18 | -t 20 \ 19 | -sc 
chrombpnet 20 | 21 | -------------------------------------------------------------------------------- /tests/archive/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -u 5 | set -o pipefail 6 | set -x 7 | 8 | # rm -rf output 9 | mkdir -p output/ 10 | 11 | python -u ../src/variant_scoring.py \ 12 | -l /oak/stanford/groups/akundaje/airanman/test_data/variant-scorer/shared/encode_variants.subset.tsv \ 13 | -g /oak/stanford/groups/akundaje/airanman/test_data/variant-scorer/shared/GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta \ 14 | -s /oak/stanford/groups/akundaje/airanman/test_data/variant-scorer/shared/GRCh38_EBV.chrom.sizes.tsv \ 15 | -m /oak/stanford/groups/akundaje/airanman/test_data/variant-scorer/ENCSR999NKW/fold_0/chrombpnet_wo_bias.h5 \ 16 | -p /oak/stanford/groups/akundaje/airanman/test_data/variant-scorer/ENCSR999NKW/peaks.subset.bed.gz \ 17 | -o output/test \ 18 | -t 20 \ 19 | -sc chrombpnet 20 | 21 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import os 3 | 4 | 5 | @pytest.fixture(scope="session") 6 | def test_data_dir(): 7 | return os.path.join(os.path.dirname(__file__), 'data') 8 | 9 | 10 | @pytest.fixture(scope="session") 11 | def out_dir(): 12 | """Fixture to provide output directory for tests""" 13 | return os.path.join(os.path.dirname(__file__), 'outputs') 14 | 15 | 16 | @pytest.fixture(scope="session") 17 | def src(): 18 | return os.path.join(os.path.dirname(__file__), '..', 'src') 19 | 20 | 21 | @pytest.fixture(scope="session") 22 | def oak_path(): 23 | """Get the Oak test data directory from the OAK environment variable or skip test""" 24 | oak_env = os.environ.get('OAK') 25 | if oak_env is None: 26 | pytest.skip("OAK environment variable not set") 27 | oak = os.path.join(oak_env, "projects/variant-scorer-test") 28 | if not os.path.exists(oak): 29 | pytest.skip(f"OAK directory not found: {oak}") 30 | return oak 31 | 32 | 33 | @pytest.fixture(scope="session") 34 | def genome_path(oak_path): 35 | """Get the genome FASTA under the Oak test data directory or skip test""" 36 | genome = os.path.join(oak_path, "GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta") 37 | if not os.path.exists(genome): 38 | pytest.skip(f"Genome file not found: {genome}") 39 | return genome 40 | 41 | 42 | @pytest.fixture(scope="session") 43 | def model_paths(oak_path): 44 | """Get the per-fold model paths under the Oak test data directory or skip test""" 45 | models = [os.path.join(oak_path, f"GM12878_DNase_models/fold_{i}/model.chrombpnet_nobias.fold_{i}.ENCSR000EMT.h5") for i in range(5)] 46 | for m in models: 47 | if not os.path.exists(m): 48 | pytest.skip(f"Model file not found: {m}") 49 | return models 50 | 51 | 52 | @pytest.fixture(scope="session") 53 | def chrom_sizes_path(oak_path): 54 | """Get the chrom sizes file under the Oak test data directory or skip test""" 55 | chrom_sizes = os.path.join(oak_path, "GRCh38_EBV_sorted_standard.chrom.sizes.tsv") 56 | if not os.path.exists(chrom_sizes): 57 | pytest.skip(f"Chrom sizes file not found: {chrom_sizes}") 58 | return chrom_sizes 59 | -------------------------------------------------------------------------------- /tests/data/caqtls.african.lcls.benchmarking.subset.tsv: -------------------------------------------------------------------------------- 1 | var.chr var.pos_hg38
var.allele1 var.allele2 var.isused pred.chrombpnet.encsr000emt.variantscore.logfc pred.chrombpnet.encsr000emt.variantscore.jsd var.snp_id var.dbsnp_rsid 2 | chr1 976215 A G True 0.0821274204 0.0259986728889963 1_976215_A_G rs7417106 3 | chr1 1038800 T G True -0.0278944548 -0.0150936196154476 1_1038800_G_T rs3121577 4 | chr1 1038819 T C True 0.0343967458 0.0154224729757663 1_1038819_C_T rs2465131 5 | chr1 1038845 G A True -0.276904978 -0.0461506149788017 1_1038845_A_G rs2488995 6 | chr1 1038916 G A True -0.154193465 -0.026456420270438 1_1038916_A_G rs2488996 7 | chr1 3787477 C T True -0.338730288 -0.0509917614602056 1_3787477_C_T rs7527973 8 | chr2 465082 T G True 0.307680838 0.0338972498788086 2_465082_T_G rs2063455 9 | chr2 554536 G C True 0.000337756656 0.0159582500002122 2_554536_C_G rs2685245 10 | chr2 585333 G A True -0.0212269412 -0.0167773052685894 2_585333_G_A rs4402801 11 | chr2 727259 C A True 0.0791642759999999 0.0210299719658457 2_727259_C_A rs4854274 12 | -------------------------------------------------------------------------------- /tests/data/test.anno_input.tsv: -------------------------------------------------------------------------------- 1 | chr pos allele1 allele2 variant_id 2 | chr1 976215 A G rs7417106 3 | chr1 1038800 T G rs3121577 4 | chr1 1038819 T C rs2465131 5 | chr1 1038845 G A rs2488995 6 | chr1 1038916 G A rs2488996 7 | chr1 3787477 C T rs7527973 8 | chr2 465082 T G rs2063455 9 | chr2 554536 G C rs2685245 10 | chr2 585333 G A rs4402801 11 | chr2 727259 C A rs4854274 12 | -------------------------------------------------------------------------------- /tests/data/test.bed: -------------------------------------------------------------------------------- 1 | chr1 976214 976215 A G rs7417106 2 | chr1 1038799 1038800 T G rs3121577 3 | chr1 1038818 1038819 T C rs2465131 4 | chr1 1038844 1038845 G A rs2488995 5 | chr1 1038915 1038916 G A rs2488996 6 | chr1 3787476 3787477 C T rs7527973 7 | chr2 465081 465082 T G rs2063455 8 | chr2 554535 554536 G C rs2685245 9 | chr2 585332 585333 G A rs4402801 10 | chr2 727258 727259 C A rs4854274 11 | -------------------------------------------------------------------------------- /tests/data/test.chrombpnet.incorrect.tsv: -------------------------------------------------------------------------------- 1 | chr1 959193 G A rs13303010 2 | chr1 959339 C T rs187243360 3 | chr1 960509 A T rs72891151 4 | chr1 960684 C N rs113034360 5 | chr1 866281 X T 1_866281_X_T 6 | chr2 1234567 C Y 2_1234567_C_Y 7 | -------------------------------------------------------------------------------- /tests/data/test.chrombpnet.incorrect2.tsv: -------------------------------------------------------------------------------- 1 | chr1 959193 G A rs13303010 2 | chr1 959339 C T rs187243360 3 | chr1 960509 A T rs72891151 4 | chr1 960684 C A rs113034360 5 | chr1 866281 C TG 1_866281_C_TG 6 | chr2 1234567 C -- 2_1234567_Cdel 7 | -------------------------------------------------------------------------------- /tests/data/test.chrombpnet.no_chr.tsv: -------------------------------------------------------------------------------- 1 | 1 976215 A G rs7417106 2 | 1 1038800 T G rs3121577 3 | 1 1038819 T C rs2465131 4 | 1 1038845 G A rs2488995 5 | 1 1038916 G A rs2488996 6 | 1 3787477 C T rs7527973 7 | 2 465082 T G rs2063455 8 | 2 554536 G C rs2685245 9 | 2 585333 G A rs4402801 10 | 2 727259 C A rs4854274 11 | -------------------------------------------------------------------------------- /tests/data/test.chrombpnet.tsv: 
-------------------------------------------------------------------------------- 1 | chr1 976215 A G rs7417106 2 | chr1 1038800 T G rs3121577 3 | chr1 1038819 T C rs2465131 4 | chr1 1038845 G A rs2488995 5 | chr1 1038916 G A rs2488996 6 | chr1 3787477 C T rs7527973 7 | chr2 465082 T G rs2063455 8 | chr2 554536 G C rs2685245 9 | chr2 585333 G A rs4402801 10 | chr2 727259 C A rs4854274 11 | -------------------------------------------------------------------------------- /tests/data/test.genes.bed: -------------------------------------------------------------------------------- 1 | chr1 950000 960000 gene_A 0 + 2 | chr1 966000 970000 gene_B 0 - 3 | chr1 1039916 1049916 gene_C 0 + 4 | chr1 4787476 4788476 gene_D 0 + 5 | chr2 465000 466000 gene_G 0 - 6 | chr2 655332 665333 gene_H 0 - 7 | chr2 907258 907259 gene_I 0 + -------------------------------------------------------------------------------- /tests/data/test.hits.bed: -------------------------------------------------------------------------------- 1 | chr1 976210 976216 motif_aa 0 + pos 2 | chr1 1038815 1038820 motif_cc 0 + pos 3 | chr1 1038843 1038848 motif_dd 0 + neg 4 | chr1 1038840 1038846 motif_dd 0 + neg 5 | chr2 554530 554537 motif_cc 0 - neg 6 | chr2 554528 554539 motif_gg 0 - pos 7 | chr2 727250 727259 motif_hh 0 + pos -------------------------------------------------------------------------------- /tests/data/test.incorrect.bed: -------------------------------------------------------------------------------- 1 | chr1 959192 G A rs13303010 2 | chr1 959338 C T rs187243360 3 | chr1 960508 A T rs72891151 4 | chr1 960683 C G rs113034360 5 | chr2 11555 G A rs79388895 6 | chr2 11606 T C rs73138516 7 | -------------------------------------------------------------------------------- /tests/data/test.original.tsv: -------------------------------------------------------------------------------- 1 | chr1 976215 rs7417106 A G 2 | chr1 1038800 rs3121577 T G 3 | chr1 1038819 rs2465131 T C 4 | chr1 1038845 rs2488995 G A 5 | chr1 1038916 rs2488996 G A 6 | chr1 3787477 rs7527973 C T 7 | chr2 465082 rs2063455 T G 8 | chr2 554536 rs2685245 G C 9 | chr2 585333 rs4402801 G A 10 | chr2 727259 rs4854274 C A 11 | -------------------------------------------------------------------------------- /tests/data/test.peaks.bed: -------------------------------------------------------------------------------- 1 | chr1 975814 976814 peak_1 0 0 0 0 0 500 2 | chr1 1036000 1037000 peak_2 0 0 0 0 0 500 3 | chr1 3787400 3788400 peak_3 0 0 0 0 0 500 4 | chr2 554500 555500 peak_4 0 0 0 0 0 500 5 | chr2 727200 728200 peak_5 0 0 0 0 0 500 -------------------------------------------------------------------------------- /tests/data/test.plink.tsv: -------------------------------------------------------------------------------- 1 | chr1 rs7417106 0 976215 A G 2 | chr1 rs3121577 0 1038800 T G 3 | chr1 rs2465131 0 1038819 T C 4 | chr1 rs2488995 0 1038845 G A 5 | chr1 rs2488996 0 1038916 G A 6 | chr1 rs7527973 0 3787477 C T 7 | chr2 rs2063455 0 465082 T G 8 | chr2 rs2685245 0 554536 G C 9 | chr2 rs4402801 0 585333 G A 10 | chr2 rs4854274 0 727259 C A 11 | -------------------------------------------------------------------------------- /tests/test_load_bed_files.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pandas as pd 3 | import pybedtools 4 | import os 5 | import sys 6 | 7 | # Add src to path to import modules 8 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src')) 9 | from utils.io import 

class TestLoadBedFiles:

    def test_load_genes(self, test_data_dir):
        """Test loading genes file"""
        file_path = os.path.join(test_data_dir, 'test.genes.bed')
        gene_bed = load_genes(file_path)

        # Check that it returns a BedTool object
        assert isinstance(gene_bed, pybedtools.BedTool)

        # Check that it has data
        assert len(gene_bed) > 0

        # Check basic BED format (at least 3 columns)
        df = gene_bed.to_dataframe()
        assert df.shape[1] >= 3

    def test_load_peaks(self, test_data_dir):
        """Test loading peaks file"""
        file_path = os.path.join(test_data_dir, 'test.peaks.bed')
        peak_bed = load_peaks(file_path)

        # Check that it returns a BedTool object
        assert isinstance(peak_bed, pybedtools.BedTool)

        # Check that it has data
        assert len(peak_bed) > 0

        # Check basic BED format (at least 3 columns)
        df = peak_bed.to_dataframe()
        assert df.shape[1] >= 3

    def test_file_not_found_genes(self):
        """Test that missing genes file raises appropriate error"""
        with pytest.raises(FileNotFoundError):
            load_genes('nonexistent_genes.bed')

    def test_file_not_found_peaks(self):
        """Test that missing peaks file raises appropriate error"""
        with pytest.raises(FileNotFoundError):
            load_peaks('nonexistent_peaks.bed')
--------------------------------------------------------------------------------
/tests/test_load_variant_table.py:
--------------------------------------------------------------------------------
import pytest
import pandas as pd
import os
import sys

# Add src to path to import modules
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
from utils.io import load_variant_table, get_variant_schema

class TestLoadVariantTable:

    def test_load_chrombpnet_schema(self, test_data_dir):
        """Test loading variants with chrombpnet schema"""
        file_path = os.path.join(test_data_dir, 'test.chrombpnet.tsv')
        df = load_variant_table(file_path, 'chrombpnet')

        # Check columns
        expected_cols = ['chr', 'pos', 'allele1', 'allele2', 'variant_id']
        assert list(df.columns) == expected_cols

        # Check data types
        assert df['chr'].dtype == 'object'
        assert df['pos'].dtype in ['int64', 'int32']

        # Check chromosome prefixes are preserved/added
        assert all(df['chr'].str.startswith('chr'))

        # Check basic structure
        assert len(df) > 0
        assert df['allele1'].notna().all()
        assert df['allele2'].notna().all()

    def test_incorrect_chrombpnet(self, test_data_dir):
        """Test loading an incorrect chrombpnet file"""
        file_path = os.path.join(test_data_dir, 'test.chrombpnet.incorrect.tsv')

        with pytest.raises(ValueError):
            load_variant_table(file_path, 'chrombpnet')
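
    # For the 'bed' schema, load_variant_table converts the 0-based BED
    # start into the 1-indexed 'pos' used by the other schemas, i.e.
    # pos == start + 1. As a hypothetical example, a BED row such as
    #
    #   chr1    976214    976215    A    G    rs7417106
    #
    # should load with pos == 976215; the test below asserts this
    # relationship for every row of the fixture.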
    def test_load_bed_schema(self, test_data_dir):
        """Test loading variants with bed schema"""
        file_path = os.path.join(test_data_dir, 'test.bed')

        # First check that the file has the right number of columns
        df_orig = pd.read_csv(file_path, sep='\t', header=None)
        expected_bed_cols = 6
        assert df_orig.shape[1] == expected_bed_cols, f"BED file should have {expected_bed_cols} columns, found {df_orig.shape[1]}"

        df = load_variant_table(file_path, 'bed')

        # Check columns
        expected_cols = ['chr', 'pos', 'end', 'allele1', 'allele2', 'variant_id']
        assert list(df.columns) == expected_cols

        # Check that no columns have NaN values
        assert not df.isnull().any().any(), "BED file has missing values"

        # Check that the 1-indexed pos column equals the 0-based BED start + 1
        assert (df['pos'] - 1).equals(df_orig[1])

    def test_incorrect_bed(self, test_data_dir):
        """Test loading an incorrect bed file"""
        file_path = os.path.join(test_data_dir, 'test.incorrect.bed')

        with pytest.raises(ValueError):
            load_variant_table(file_path, 'bed')

    def test_load_plink_schema(self, test_data_dir):
        """Test loading variants with plink schema"""
        file_path = os.path.join(test_data_dir, 'test.plink.tsv')
        df = load_variant_table(file_path, 'plink')

        # Check the remaining columns (the 'ignore1' column is dropped on load)
        expected_cols = ['chr', 'variant_id', 'pos', 'allele1', 'allele2']
        assert list(df.columns) == expected_cols

    def test_chromosome_prefix_addition(self, test_data_dir):
        """Test that chr prefix is added when missing"""
        # Load a file whose chromosome names lack the chr prefix
        file_path = os.path.join(test_data_dir, 'test.chrombpnet.no_chr.tsv')
        df = load_variant_table(file_path, 'chrombpnet')

        # Check that chr prefix was added
        assert all(df['chr'].str.startswith('chr'))

    def test_invalid_alleles(self, test_data_dir):
        """Test that files with invalid allele characters are rejected"""
        file_path = os.path.join(test_data_dir, 'test.chrombpnet.incorrect.tsv')

        with pytest.raises(ValueError, match="Invalid characters"):
            load_variant_table(file_path, 'chrombpnet')

    def test_invalid_alleles2(self, test_data_dir):
        """Test that files with invalid allele characters are rejected"""
        file_path = os.path.join(test_data_dir, 'test.chrombpnet.incorrect2.tsv')

        with pytest.raises(ValueError, match="Invalid allele"):
            load_variant_table(file_path, 'chrombpnet')

    def test_invalid_schema(self, test_data_dir):
        """Test that invalid schema raises appropriate error"""
        file_path = os.path.join(test_data_dir, 'test.chrombpnet.tsv')

        with pytest.raises(KeyError):
            load_variant_table(file_path, 'invalid_schema')

    def test_file_not_found(self):
        """Test that missing file raises appropriate error"""
        with pytest.raises(FileNotFoundError):
            load_variant_table('nonexistent_file.tsv', 'chrombpnet')

    def test_get_variant_schema(self):
        """Test the get_variant_schema helper function"""
        # Test all known schemas
        schemas = ['original', 'plink', 'bed', 'chrombpnet']

        for schema in schemas:
            columns = get_variant_schema(schema)
            assert isinstance(columns, list)
            assert len(columns) > 0
            assert 'chr' in columns
            assert 'pos' in columns

        # Test invalid schema
        with pytest.raises(KeyError):
            get_variant_schema('invalid_schema')
--------------------------------------------------------------------------------
/tests/test_one_hot.py:
--------------------------------------------------------------------------------
import pytest
import numpy as np
import os
import sys

# Add src to path to import modules
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
from utils.one_hot import dna_to_one_hot, one_hot_to_dna
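
# Encoding convention exercised by these tests: bases map to one-hot rows
# in alphabetical order (A=0, C=1, G=2, T=3), stored as int8, and any
# non-ACGT character becomes an all-zero row. For instance:
#
#   dna_to_one_hot(["AC"])
#   # -> array([[[1, 0, 0, 0],
#   #            [0, 1, 0, 0]]], dtype=int8)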

class TestOneHotConversion:

    def test_dna_to_one_hot_simple(self):
        """Test basic DNA to one-hot conversion"""
        seqs = ["ACGT"]
        result = dna_to_one_hot(seqs)

        # Expected one-hot encoding for "ACGT" (alphabetical order: A=0, C=1, G=2, T=3)
        expected = np.array([[[1, 0, 0, 0],  # A
                              [0, 1, 0, 0],  # C
                              [0, 0, 1, 0],  # G
                              [0, 0, 0, 1]]], dtype=np.int8)  # T

        assert result.shape == (1, 4, 4)
        np.testing.assert_array_equal(result, expected)

    def test_dna_to_one_hot_multiple_sequences(self):
        """Test conversion of multiple sequences"""
        seqs = ["ATCG", "GCTA", "AAAA"]
        result = dna_to_one_hot(seqs)

        assert result.shape == (3, 4, 4)

        # Check first sequence "ATCG"
        expected_first = [[1, 0, 0, 0],  # A
                          [0, 0, 0, 1],  # T
                          [0, 1, 0, 0],  # C
                          [0, 0, 1, 0]]  # G
        np.testing.assert_array_equal(result[0], expected_first)

        # Check second sequence "GCTA"
        expected_second = [[0, 0, 1, 0],  # G
                           [0, 1, 0, 0],  # C
                           [0, 0, 0, 1],  # T
                           [1, 0, 0, 0]]  # A
        np.testing.assert_array_equal(result[1], expected_second)

        # Check third sequence "AAAA"
        expected_third = [[1, 0, 0, 0],  # A
                          [1, 0, 0, 0],  # A
                          [1, 0, 0, 0],  # A
                          [1, 0, 0, 0]]  # A
        np.testing.assert_array_equal(result[2], expected_third)

    def test_dna_to_one_hot_lowercase(self):
        """Test that lowercase sequences are converted to uppercase"""
        seqs = ["atcg"]
        result = dna_to_one_hot(seqs)

        expected = np.array([[[1, 0, 0, 0],  # A
                              [0, 0, 0, 1],  # T
                              [0, 1, 0, 0],  # C
                              [0, 0, 1, 0]]], dtype=np.int8)  # G

        np.testing.assert_array_equal(result, expected)

    def test_dna_to_one_hot_invalid_bases(self):
        """Test that invalid bases get all-zero encoding"""
        seqs = ["ANCG"]  # N is not a valid base
        result = dna_to_one_hot(seqs)

        expected = np.array([[[1, 0, 0, 0],  # A
                              [0, 0, 0, 0],  # N -> all zeros
                              [0, 1, 0, 0],  # C
                              [0, 0, 1, 0]]], dtype=np.int8)  # G

        np.testing.assert_array_equal(result, expected)

    def test_one_hot_to_dna_simple(self):
        """Test basic one-hot to DNA conversion"""
        one_hot = np.array([[[1, 0, 0, 0],  # A
                             [0, 1, 0, 0],  # C
                             [0, 0, 1, 0],  # G
                             [0, 0, 0, 1]]], dtype=np.int8)  # T

        result = one_hot_to_dna(one_hot)
        expected = ["ACGT"]

        assert result == expected

    def test_one_hot_to_dna_multiple_sequences(self):
        """Test conversion of multiple one-hot sequences"""
        one_hot = np.array([[[1, 0, 0, 0],  # A
                             [0, 0, 0, 1],  # T
                             [0, 1, 0, 0],  # C
                             [0, 0, 1, 0]],  # G
                            [[0, 0, 1, 0],  # G
                             [0, 1, 0, 0],  # C
                             [0, 0, 0, 1],  # T
                             [1, 0, 0, 0]]], dtype=np.int8)  # A

        result = one_hot_to_dna(one_hot)
        expected = ["ATCG", "GCTA"]

        assert result == expected

    def test_one_hot_to_dna_all_zeros(self):
        """Test that all-zero encodings convert to N"""
        one_hot = np.array([[[1, 0, 0, 0],  # A
                             [0, 0, 0, 0],  # all zeros -> N
                             [0, 1, 0, 0],  # C
                             [0, 0, 1, 0]]], dtype=np.int8)  # G

        result = one_hot_to_dna(one_hot)
        expected = ["ANCG"]

        assert result == expected
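
    # Roundtrip property checked next: dna_to_one_hot and one_hot_to_dna
    # are expected to be mutual inverses on A/C/G/T, while any other
    # character encodes to an all-zero row and decodes back as 'N'
    # (so "GCTX" comes back as "GCTN").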
    def test_roundtrip_conversion(self):
        """Test that DNA -> one-hot -> DNA is consistent"""
        original_seqs = ["ATCG", "GCTA", "AAAA", "TTTT"]

        # Convert to one-hot and back
        one_hot = dna_to_one_hot(original_seqs)
        recovered_seqs = one_hot_to_dna(one_hot)

        assert recovered_seqs == original_seqs

    def test_roundtrip_with_invalid_bases(self):
        """Test roundtrip with invalid bases (should become N)"""
        original_seqs = ["ANCG", "GCTX"]
        expected_seqs = ["ANCG", "GCTN"]  # X becomes N

        # Convert to one-hot and back
        one_hot = dna_to_one_hot(original_seqs)
        recovered_seqs = one_hot_to_dna(one_hot)

        assert recovered_seqs == expected_seqs

    def test_equal_length_requirement(self):
        """Test that sequences must be equal length"""
        seqs = ["ATCG", "GC"]  # Different lengths

        with pytest.raises(AssertionError):
            dna_to_one_hot(seqs)

    def test_empty_sequence(self):
        """Test handling of empty sequences"""
        seqs = [""]
        result = dna_to_one_hot(seqs)

        assert result.shape == (1, 0, 4)

        # Test roundtrip
        recovered = one_hot_to_dna(result)
        assert recovered == [""]

    def test_data_types(self):
        """Test that output data types are correct"""
        seqs = ["ATCG"]
        one_hot = dna_to_one_hot(seqs)

        # Should be int8
        assert one_hot.dtype == np.int8

        # Should contain only 0s and 1s
        assert np.all(np.isin(one_hot, [0, 1]))

        # Test one_hot_to_dna returns strings
        dna_seqs = one_hot_to_dna(one_hot)
        assert isinstance(dna_seqs, list)
        assert all(isinstance(seq, str) for seq in dna_seqs)
--------------------------------------------------------------------------------
/tests/test_variant_annotation.py:
--------------------------------------------------------------------------------
import pytest
import subprocess
import os
import sys
import pandas as pd
from pathlib import Path

class TestVariantAnnotationCLI:

    @pytest.fixture(scope="class")
    def script_path(self, src):
        """Fixture to provide the path to the variant_annotation.py script"""
        return os.path.join(src, "variant_annotation.py")

    def test_arg_validation(self, test_data_dir, script_path, out_dir):
        """Test that an error is raised when none of genes, peaks, or hits is provided."""

        test_variants = os.path.join(test_data_dir, 'test.bed')
        if not os.path.exists(test_variants):
            pytest.skip("Test variant file not found")

        output_prefix = os.path.join(out_dir, "scores")

        cmd = [
            sys.executable, script_path,
            '--list', test_variants,
            '--schema', 'bed',
            '--out_prefix', output_prefix
        ]

        # Run the command and check for error
        result = subprocess.run(cmd, capture_output=True, text=True)

        assert result.returncode != 0, "Command should fail without peaks, hits, or genes"

        # Check stdout instead of stderr since that's where the error message appears
        assert "at least one of" in result.stdout.lower(), \
            f"Expected error message not found in stdout. Got: {result.stdout}"
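
    # Happy-path run: with --peaks, --hits, and --genes all supplied, the
    # script is expected to write <out_prefix>.annotations.tsv with (at
    # least) 'hits_motifs', 'peak_overlap', and 'closest_gene_1' columns,
    # which are checked below against the small BED fixtures in tests/data/.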
    def test_variant_annotation(self, test_data_dir, script_path, out_dir):
        """Test the variant annotation script with valid inputs."""

        test_variants = os.path.join(test_data_dir, 'test.anno_input.tsv')
        if not os.path.exists(test_variants):
            pytest.skip("Test variant file not found")

        output_prefix = os.path.join(out_dir, "scores")

        cmd = [
            sys.executable, script_path,
            '--list', test_variants,
            '--out_prefix', output_prefix,
            '--peaks', os.path.join(test_data_dir, 'test.peaks.bed'),
            '--hits', os.path.join(test_data_dir, 'test.hits.bed'),
            '--genes', os.path.join(test_data_dir, 'test.genes.bed')
        ]

        result = subprocess.run(cmd, capture_output=True, text=True)

        assert result.returncode == 0, f"Command failed with error: {result.stderr}"

        output_file = f"{output_prefix}.annotations.tsv"
        assert os.path.exists(output_file), "Output file was not created"

        df = pd.read_csv(output_file, sep='\t')

        assert not df.empty, "Output DataFrame is empty"

        # Test that the hit overlaps are correct (motif order within a cell may vary):
        expected_motifs = ['motif_aa', '-', 'motif_cc', 'motif_dd', '-', '-', '-', 'motif_gg,motif_cc', '-', 'motif_hh']
        expected_motifs2 = ['motif_aa', '-', 'motif_cc', 'motif_dd', '-', '-', '-', 'motif_cc,motif_gg', '-', 'motif_hh']
        assert df['hits_motifs'].tolist() == expected_motifs or df['hits_motifs'].tolist() == expected_motifs2, \
            f"Expected hits motifs {expected_motifs} but got {df['hits_motifs'].tolist()}"

        expected_peak_overlap = [True, False, False, False, False, True, False, True, False, True]
        assert df['peak_overlap'].tolist() == expected_peak_overlap, \
            f"Expected peak overlap {expected_peak_overlap} but got {df['peak_overlap'].tolist()}"

        expected_closest_genes = ['gene_B', 'gene_C', 'gene_C', 'gene_C', 'gene_C', 'gene_D', 'gene_G', 'gene_G', 'gene_H', 'gene_H']
        assert df['closest_gene_1'].tolist() == expected_closest_genes, \
            f"Expected closest genes {expected_closest_genes} but got {df['closest_gene_1'].tolist()}"
--------------------------------------------------------------------------------
/tests/test_variant_scoring.py:
--------------------------------------------------------------------------------
import pytest
import subprocess
import tempfile
import os
import sys
import pandas as pd
from pathlib import Path

class TestVariantScoringCLI:

    @pytest.fixture(scope="class")
    def script_path(self, src):
        """Fixture to provide the path to the variant_scoring.py script"""
        return os.path.join(src, "variant_scoring.py")

    @pytest.fixture(scope="class")
    def script_path_per_chrom(self, src):
        """Fixture to provide the path to the variant_scoring.per_chrom.py script"""
        return os.path.join(src, "variant_scoring.per_chrom.py")
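
    # The assertions below pin down the scoring CLI's core arguments:
    # --list, --genome, --model, --out_prefix, and --chrom_sizes must all
    # appear in the --help output. A typical invocation (all paths
    # hypothetical) might look like:
    #
    #   python src/variant_scoring.py \
    #       --list variants.tsv --genome hg38.fa --model fold_0.h5 \
    #       --out_prefix out/fold_0 --chrom_sizes hg38.chrom.sizes \
    #       --schema chrombpnet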
    @pytest.mark.oak
    def test_variant_scoring_help(self, script_path):
        """Test that variant_scoring.py shows help without errors"""
        if not os.path.exists(script_path):
            pytest.skip(f"Script {script_path} not found")

        cmd = [sys.executable, script_path, '--help']
        result = subprocess.run(cmd, capture_output=True, text=True)

        # Should exit successfully and show help
        assert result.returncode == 0
        assert 'usage:' in result.stdout.lower() or 'help' in result.stdout.lower()
        # Check that required arguments are mentioned
        assert '--list' in result.stdout
        assert '--genome' in result.stdout
        assert '--model' in result.stdout
        assert '--out_prefix' in result.stdout
        assert '--chrom_sizes' in result.stdout

    @pytest.mark.oak
    def test_variant_scoring_missing_required_args(self, script_path):
        """Test that variant_scoring.py fails gracefully with missing required arguments"""
        if not os.path.exists(script_path):
            pytest.skip(f"Script {script_path} not found")

        cmd = [sys.executable, script_path]
        result = subprocess.run(cmd, capture_output=True, text=True)

        # Should fail with non-zero exit code
        assert result.returncode != 0

        # Should mention missing required arguments
        error_text = result.stderr.lower()
        assert 'required' in error_text or 'argument' in error_text or 'missing' in error_text

    @pytest.mark.gpu
    @pytest.mark.oak
    def test_variant_scoring_no_peaks(self, out_dir, script_path, test_data_dir, genome_path, model_paths, chrom_sizes_path):
        """Test variant_scoring.py with real genome/model data (requires env vars and GPU)"""
        if not os.path.exists(script_path):
            pytest.skip(f"Script {script_path} not found")

        test_variants = os.path.join(test_data_dir, 'test.chrombpnet.tsv')
        if not os.path.exists(test_variants):
            pytest.skip("Test variant file not found")

        # Run for each fold
        for fold in range(5):
            model_path = model_paths[fold]
            output_prefix = os.path.join(out_dir, f"fold_{fold}")

            cmd = [
                sys.executable, script_path,
                '--list', test_variants,
                '--genome', genome_path,
                '--model', model_path,
                '--out_prefix', output_prefix,
                '--chrom_sizes', chrom_sizes_path,
                '--num_shuf', '2',  # Use a small number for testing
                '--schema', 'chrombpnet',
                '--no_hdf5'  # Skip HDF5 output for faster testing
            ]

            result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)

            if result.returncode != 0:
                print(f"STDOUT: {result.stdout}")
                print(f"STDERR: {result.stderr}")

            # Check if it completed successfully
            assert result.returncode == 0, f"Script failed: {result.stderr}"

            # Check output file exists
            output_file = f"{output_prefix}.variant_scores.tsv"
            assert os.path.exists(output_file), "Output file not created"

            # Basic validation of output
            df = pd.read_csv(output_file, sep='\t')
            assert len(df) > 0, "Output file is empty"
            assert 'logfc' in df.columns, "Missing logfc column"
            assert 'jsd' in df.columns, "Missing jsd column"

            # Validate that we have the expected number of variants
            input_df = pd.read_csv(test_variants, sep='\t', header=None)
            assert len(df) == len(input_df), "Output has different number of variants than input"
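
    # The scoring test above leaves one table per model fold in out_dir,
    # named fold_<k>.variant_scores.tsv. The next test passes those five
    # tables to variant_summary_across_folds.py, which is expected to
    # average them into summary.mean.variant_scores.tsv; the accuracy test
    # afterwards compares that mean logFC against the caQTL benchmark
    # subset in tests/data/.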
    @pytest.mark.oak
    def test_variant_summary_across_folds(self, out_dir, script_path, test_data_dir):
        """Test variant_summary_across_folds.py (depends on scoring test)"""
        # This test depends on the scoring test having run successfully
        # Check that fold output files exist
        for fold in range(5):
            output_file = os.path.join(out_dir, f"fold_{fold}.variant_scores.tsv")
            if not os.path.exists(output_file):
                pytest.skip("Scoring test outputs not found. Requires test_variant_scoring_no_peaks.")

        # Run summary script
        summary_script = os.path.join(os.path.dirname(script_path), 'variant_summary_across_folds.py')
        if not os.path.exists(summary_script):
            pytest.skip(f"Summary script {summary_script} not found")

        summary_cmd = [
            sys.executable, summary_script,
            '--score_dir', out_dir,
            '--score_list'
        ]
        # Add each file as a separate argument
        for fold in range(5):
            summary_cmd.append(f"fold_{fold}.variant_scores.tsv")

        summary_cmd.extend([
            '--out_prefix', os.path.join(out_dir, 'summary'),
            '--schema', 'chrombpnet'
        ])

        result = subprocess.run(summary_cmd, capture_output=True, text=True)
        assert result.returncode == 0, f"Summary script failed: {result.stderr}"

        # Check output file exists
        summary_file = os.path.join(out_dir, 'summary.mean.variant_scores.tsv')
        assert os.path.exists(summary_file), "Summary output file not created"

    @pytest.mark.oak
    def test_variant_scoring_accuracy(self, out_dir, test_data_dir):
        """Test variant scoring accuracy against known caQTLs (depends on summary test)"""
        # Check that summary output exists
        summary_file = os.path.join(out_dir, 'summary.mean.variant_scores.tsv')
        if not os.path.exists(summary_file):
            pytest.skip("Summary test output not found. Requires test_variant_summary_across_folds.")

        # Load CaQTL reference data
        caqtl_file = os.path.join(test_data_dir, 'caqtls.african.lcls.benchmarking.subset.tsv')
        if not os.path.exists(caqtl_file):
            pytest.skip("CaQTL reference file not found")

        # Load scoring results
        scores_df = pd.read_csv(summary_file, sep='\t')
        caqtl_df = pd.read_csv(caqtl_file, sep='\t')

        # Basic validation
        assert len(scores_df) > 0, "No scoring results found"
        assert len(caqtl_df) > 0, "No CaQTL data found"

        # Check that we have the expected columns
        assert 'logfc.mean' in scores_df.columns, "Missing logfc.mean column in scores"
        assert 'pred.chrombpnet.encsr000emt.variantscore.logfc' in caqtl_df.columns, "Missing logfc column in ground truth"

        # Merge datasets on variant identifier
        if 'variant_id' in scores_df.columns and 'var.dbsnp_rsid' in caqtl_df.columns:
            merged_df = pd.merge(scores_df, caqtl_df, left_on='variant_id', right_on='var.dbsnp_rsid', how='inner')
            assert len(merged_df) == len(scores_df), "Not every scored variant was found in the CaQTL reference data"

            # Check tolerance between predicted and ground truth logfc
            tolerance = 1e-3  # Adjust as needed
            diff = abs(merged_df['logfc.mean'] - merged_df['pred.chrombpnet.encsr000emt.variantscore.logfc'])

            # Printed to the pytest output (visible with -s or on failure)
            print("\nDifference stats:")
            print(diff.describe())

            within_tolerance = (diff <= tolerance)

            # Check that all the variants are within tolerance
            assert all(within_tolerance), f"Not all variants within tolerance: {within_tolerance.sum()}/{len(merged_df)}"

        else:
            pytest.skip("Cannot merge datasets - missing variant_id or dbsnp_rsid column")
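
    # Per-chromosome mode: rather than one combined table,
    # variant_scoring.per_chrom.py is expected to write one file per input
    # chromosome, named <out_prefix>.<chrom>.variant_scores.tsv (e.g.
    # fold_0.chr1.variant_scores.tsv for this input), each containing
    # exactly the variants on that chromosome.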
    @pytest.mark.gpu
    @pytest.mark.oak
    def test_variant_scoring_per_chrom(self, out_dir, script_path_per_chrom, test_data_dir, genome_path, model_paths, chrom_sizes_path):
        """Test variant_scoring.per_chrom.py with real genome/model data (requires env vars and GPU)"""
        if not os.path.exists(script_path_per_chrom):
            pytest.skip(f"Script {script_path_per_chrom} not found")

        test_variants = os.path.join(test_data_dir, 'test.chrombpnet.tsv')
        if not os.path.exists(test_variants):
            pytest.skip("Test variant file not found")

        # Check inputs
        input_df = pd.read_csv(test_variants, sep='\t', header=None)
        chrms = input_df[0].unique()

        # Dictionary of number of variants per chromosome
        variant_counts = input_df[0].value_counts().to_dict()

        # Run for fold 0
        model_path = model_paths[0]
        output_prefix = os.path.join(out_dir, "fold_0")

        cmd = [
            sys.executable, script_path_per_chrom,
            '--list', test_variants,
            '--genome', genome_path,
            '--model', model_path,
            '--out_prefix', output_prefix,
            '--chrom_sizes', chrom_sizes_path,
            '--num_shuf', '2',  # Use a small number for testing
            '--schema', 'chrombpnet',
            '--no_hdf5'  # Skip HDF5 output for faster testing
        ]

        result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)

        if result.returncode != 0:
            print(f"STDOUT: {result.stdout}")
            print(f"STDERR: {result.stderr}")

        # Check if it completed successfully
        assert result.returncode == 0, f"Script failed: {result.stderr}"

        # Check output files exist
        for chrom in chrms:
            output_file = f"{output_prefix}.{chrom}.variant_scores.tsv"
            assert os.path.exists(output_file), f"Output file for {chrom} not created"

            df = pd.read_csv(output_file, sep='\t')
            assert 'logfc' in df.columns, "Missing logfc column"
            assert 'jsd' in df.columns, "Missing jsd column"

            # Check that we have the right number of variants
            expected_count = variant_counts.get(chrom, 0)
            assert len(df) == expected_count, f"Output for {chrom} has {len(df)} variants, expected {expected_count}"
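
    # With --merge, the per-chromosome script is additionally expected to
    # combine the per-chromosome tables into a single
    # <out_prefix>.variant_scores.tsv covering every input variant; that
    # merged file is what the final test below validates.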
    @pytest.mark.gpu
    @pytest.mark.oak
    def test_merge_chroms(self, out_dir, script_path_per_chrom, test_data_dir, genome_path, model_paths, chrom_sizes_path):
        """Test variant_scoring.per_chrom.py with --merge, which combines the per-chromosome outputs"""
        if not os.path.exists(script_path_per_chrom):
            pytest.skip(f"Script {script_path_per_chrom} not found")

        test_variants = os.path.join(test_data_dir, 'test.chrombpnet.tsv')
        if not os.path.exists(test_variants):
            pytest.skip("Test variant file not found")

        # Read the input list once for the final count check
        input_df = pd.read_csv(test_variants, sep='\t', header=None)

        # Run for fold 0
        model_path = model_paths[0]
        output_prefix = os.path.join(out_dir, "merged_fold_0")

        cmd = [
            sys.executable, script_path_per_chrom,
            '--list', test_variants,
            '--genome', genome_path,
            '--model', model_path,
            '--out_prefix', output_prefix,
            '--chrom_sizes', chrom_sizes_path,
            '--num_shuf', '2',  # Use a small number for testing
            '--schema', 'chrombpnet',
            '--no_hdf5',  # Skip HDF5 output for faster testing
            '--merge'
        ]

        result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)

        if result.returncode != 0:
            print(f"STDOUT: {result.stdout}")
            print(f"STDERR: {result.stderr}")

        # Check if it completed successfully
        assert result.returncode == 0, f"Script failed: {result.stderr}"

        # Load the merged output file
        output_file = f"{output_prefix}.variant_scores.tsv"
        df = pd.read_csv(output_file, sep='\t')

        assert 'logfc' in df.columns, "Missing logfc column"
        assert 'jsd' in df.columns, "Missing jsd column"

        # Check number of variants
        assert len(df) == len(input_df), "Merged output has different number of variants than input"
--------------------------------------------------------------------------------