├── CHANGELOG ├── LICENSE ├── README.md ├── bpnet ├── bpnetlite ├── __init__.py ├── attribute.py ├── bpnet.py ├── chrombpnet.py ├── io.py ├── logging.py ├── losses.py ├── marginalize.py └── performance.py ├── chrombpnet ├── example_jsons ├── README.md ├── bpnet_attribute_example.json ├── bpnet_fit_example.json ├── bpnet_marginalize_example.json ├── bpnet_pipeline_example.json ├── bpnet_predict_example.json ├── chrombpnet_fit_example.json └── chrombpnet_pipeline_example.json ├── examples ├── ENCSR000AKO.bed ├── ENCSR000AKO_minus.bigWig └── ENCSR000AKO_plus.bigWig ├── notebooks ├── bpnet │ ├── bpnet-fit.ipynb │ ├── bpnet_fit_example.json │ └── bpnet_predict_example.json └── rbpnet │ ├── rbpnet-fit.ipynb │ ├── rbpnet_fit_example.json │ └── rbpnet_predict_example.json └── setup.py /CHANGELOG: -------------------------------------------------------------------------------- 1 | =============== 2 | Release History 3 | =============== 4 | 5 | 6 | Version 0.8.1 7 | ============== 8 | 9 | Highlights 10 | ---------- 11 | 12 | - Added robustness toward other characters in the nucleotide alphabet. 13 | Anything not A, C, G, or T gets ignored. This robustness has been added to 14 | PeakGenerator and the command line arguments. 15 | 16 | 17 | Version 0.8.0 18 | ============== 19 | 20 | Highlights 21 | ---------- 22 | 23 | - When training ChromBPNet bias models from the command line and no loci 24 | are provided for training the bias model, the default is changed to inherit 25 | from `negatives` instead of `loci`. Colloquially, when the user provides a 26 | set of negatives and peak loci to train ChromBPNet, the default is now that 27 | the bias model will be trained on the negatives instead of incorrectly on 28 | the peaks. 29 | - When calculating the maximum number of reads a negative region can have 30 | when training the bias model, changed from using the minimum number of reads 31 | in the peaks to the 1st quantile of reads because it is more robust. 32 | 33 | 34 | Version 0.7.4 35 | ============== 36 | 37 | Highlights 38 | ---------- 39 | 40 | - Re-added the `attribute.py` file back in with a `deep_lift_shap` function 41 | that wraps tangermeme's function but passes in the layers that must be 42 | registered. 43 | 44 | 45 | Version 0.7.3 46 | ============== 47 | 48 | Highlights 49 | ---------- 50 | 51 | - Added `BasePairNet` in `bpnetlite.bpnet` which is an implementation of the 52 | model from the official `basepairmodels` 53 | - Added `BasePairNet.from_bpnet` to load TensorFlow-trained models from 54 | `basepairmodels` into the PyTorch wrapper 55 | - Removed a few dependencies that are no longer needed after using 56 | tangermeme 57 | 58 | 59 | 60 | Version 0.7.2 61 | ============== 62 | 63 | Highlights 64 | ---------- 65 | 66 | - Complete inclusion of tangermeme as the backend for operations 67 | - Remove the `predict` method for models in favor of `tangermeme.predict` 68 | - Remove attribute.py in favor of `tangermeme.deep_lift_shap` 69 | - Remove marginalization functions in favor of `tangermeme.marginalize` 70 | - Remove plotting functions in favor of `tangermeme.plot` 71 | - Alter the `bpnet` and `chrombpnet` command-line tools to account for 72 | these changes. 73 | - Add in `_Log` and `_Exp` as layers for the ChromBPNet model so that they 74 | can be registered as non-linear functions for `deep_lift_shap`. 75 | 76 | 77 | 78 | Version 0.7.1 79 | ============== 80 | 81 | Highlights 82 | ---------- 83 | 84 | - Begin inclusion of tangermeme into the backend. 85 | 86 | 87 | Version 0.7.0 88 | ============== 89 | 90 | Highlights 91 | ---------- 92 | 93 | - Changed the function name from `calculate_attributions` to `attribute` 94 | to be more in line with the `predict` and `marginalize` functions. The 95 | functionality and usage should be the same. 96 | - Changed the nomenclature from "interpret" to "attribute" to be more 97 | consistent with the name of the function and what is used colloquially. 98 | 99 | 100 | 101 | Version 0.6.0 102 | ============== 103 | 104 | Highlights 105 | ---------- 106 | 107 | - Replaced the negative sampling code with a simpler approach that only 108 | considers bins of signal rather than operates at bp resolution. This code 109 | is much faster and more robust but may produce slightly worse GC matches. 110 | - The negative sampling code now allows you to pass in a bigwig so that 111 | only regions that pass a threshold are selected. 112 | 113 | 114 | 115 | Version 0.5.7 116 | ============== 117 | 118 | Highlights 119 | ---------- 120 | 121 | - Changed the `warning_threshold` argument to only print a warning rather 122 | than end the process when the model exceeds it. 123 | - Added support for plotting annotations alongside `plot_attributions` 124 | - Fixed various minor bugs. 125 | 126 | 127 | 128 | Version 0.5.6 129 | ============== 130 | 131 | Highlights 132 | ---------- 133 | 134 | - Changed the shape of the returned one-hot encoded sequences to match 135 | the documentation. 136 | - Fixed an issue with dinucleotide shuffling when not all nucleotides 137 | are present. 138 | 139 | 140 | 141 | Version 0.5.5 142 | ============== 143 | 144 | Highlights 145 | ---------- 146 | 147 | - Fixed an issue with ChromBPNet reading. 148 | 149 | 150 | 151 | Version 0.5.4 152 | ============== 153 | 154 | Highlights 155 | ---------- 156 | 157 | - Added in reading of TensorFlow-formatted ChromBPNet models from the 158 | official repo using the `from_chrombpnet` commands to the BPNet and 159 | ChromBPNet objects. 160 | 161 | 162 | 163 | Version 0.5.2 164 | ============== 165 | 166 | Highlights 167 | ---------- 168 | 169 | - Fixed issue where non-linear operations in DeepLiftShap were not 170 | registered correctly and hence causing minor divergences. Through the 171 | use of an ugly wrapper object this has been fixed. 172 | - Added in `print_convergence_deltas` and `warning_threshold` to the 173 | `calculate_attributions` function and the `DeepLiftShap` object. The first 174 | will print convergence deltas for every example that gets explained and the 175 | second will raise a warning if the divergence is higher than it. 176 | 177 | 178 | 179 | Version 0.5.0 180 | ============== 181 | 182 | Highlights 183 | ---------- 184 | 185 | - Extended support for the `chrombpnet` command-line tool 186 | - Now has mirrored functionality of the `bpnet` command-line tool 187 | - `chrombpnet pipeline` now mirrors `bpnet pipeline` except that it will 188 | run each of the reports on each of the three models: the full ChromBPNet 189 | model, the accessibility model, and the bias model. It will train a bias 190 | model and an accessibility model if not provided. 191 | - Changed the ChromBPNet object to be compatible with the `bpnet` command 192 | options. 193 | - Fixed issue with attributions where performance would degrade over time. 194 | 195 | 196 | 197 | Version 0.4.0 198 | ============== 199 | 200 | Highlights 201 | ---------- 202 | 203 | - Extended support for the `bpnet` command-line tool 204 | - Added in `marginalize` command-line option for generating those reports 205 | - Added in `pipeline` command-line option for running a full pipeline from 206 | model training to inference, attribution, tfmodisco, and marginalization 207 | 208 | 209 | 210 | Version 0.3.0 211 | ============== 212 | 213 | Highlights 214 | ---------- 215 | 216 | - I forgot. 217 | 218 | 219 | 220 | Version 0.2.0 221 | ============== 222 | 223 | Highlights 224 | ---------- 225 | 226 | - Addition of a `ChromBPNet` model 227 | - Addition of an explicit, shared, `Logger` class 228 | - "Peak" semantics have been switched to "locus" semantics 229 | 230 | 231 | chrombpnet.py 232 | ------------- 233 | 234 | - Newly added. 235 | - This file contains the `ChromBPNet` class, which is a wrapper that 236 | takes in two BPNet objects: a pre-trained bias model, and an untrained 237 | accessibility model, and specifies the training procedure for training 238 | the accessibility model. 239 | 240 | 241 | io.py 242 | ----- 243 | 244 | - The semantics of "peaks", e.g. `extract_peaks`, has been changed to loci, 245 | e.g. `extract_loci`, and the associated keywords (now `loci` from `peaks`) 246 | can take in a list or tuple of files to interleave them. This means you 247 | can now train on peaks and background regions. 248 | 249 | 250 | logging.py 251 | ---------- 252 | 253 | - Newly added. 254 | - This file contains the `Logger` class which is a simple way to record 255 | and display statistics during training. 256 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Jacob Schreiber 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # bpnet-lite 2 | 3 | [![PyPI Downloads](https://static.pepy.tech/badge/bpnet-lite)](https://pepy.tech/projects/bpnet-lite) 4 | 5 | > **Note** 6 | > IMPORTANT: bpnet-lite is not meant to replace the full service implementations of BPNet or ChromBPNet. Please see the official repositories for those projects for TensorFlow/Keras implementations of those models along with complete tutorials on how to use them effectively. 7 | 8 | bpnet-lite is a lightweight version of BPNet [[paper](https://www.nature.com/articles/s41588-021-00782-6) | [code](https://github.com/kundajelab/basepairmodels)] and ChromBPNet [paper | [code](https://github.com/kundajelab/chrombpnet)], containing PyTorch reference implementations of both models. Additionally, it contains efficient data loaders and common operations one would do with these trained models including calculating attributions, running TF-MoDISco, and performing marginalization experiments. These operations are wrapped in command-line tools for ease-of-use and organized in a pipeline command representing the standard workflow. This package is meant to be used for quickly exploring data sets using BPNet or ChromBPNet and as a springboard for prototyping new ideas that involve modifying the code. 9 | 10 | #### Installation 11 | 12 | `pip install bpnet-lite`. 13 | 14 | #### Data Preprocessing 15 | 16 | > **Note** 17 | > As of v0.9.0 you can now include .BAM and .tsv files in the JSONs for the bpnet-lite command-line tool and the conversion to bigWigs will be automatically performed using bam2bw. Because bam2bw is fast (around ~500k records/second) it is not always necessary to separately preprocess your data. 18 | 19 | BPNet and ChromBPNet models are both trained on read ends that have been mapped at basepair resolution (hence, the name). Accordingly, the data used for training is made up of integer counts with one count per read in the file (or two counts per fragment). Once you have used your favorite tool to align your FASTQ of reads to your genome of interest (we recommend ChroMAP), you should use [bam2bw](https://github.com/jmschrei/bam2bw) to convert your BAM or fragment tsv file to bigWig files. 20 | 21 | If you are using stranded data, e.g., ChIP-seq: 22 | 23 | ``` 24 | bam2bw .bam .bam ... -s .chrom.sizes/.fa -n -v 25 | ``` 26 | 27 | This command will create two bigWig files, one for the + strand and one for the - strand, using the name provided as the suffix. 28 | 29 | If you are using unstranded data, e.g., ATAC-seq: 30 | 31 | ``` 32 | bam2bw .bam .bam ... -s .chrom.sizes/.fa -n -v -u 33 | ``` 34 | 35 | If you have a file of fragments, usually formatted as a .tsv or .tsv.gz and coming from ATAC-seq or scATAC-seq data, you can use the `-f` flag to map both the start and end (end-1, specifically) instead of just the 5' end. 36 | 37 | ``` 38 | bam2bw .tsv.gz .tsv.gz ... -s .chrom.sizes/.fa -n -v -u -f 39 | ``` 40 | 41 | These tools require positive loci (usually peaks for the respective activity) and negative loci (usually GC-matched background sequences) for training. The positive loci must be provided from the user, potentially by applying a tool like MACS2 to your .BAM files. The negative loci can be calculated using a command-line tool in this package, described later, or by specifying in the JSON that `find_negatives: true`. 42 | 43 | ## BPNet 44 | 45 | ![image](https://github.com/jmschrei/bpnet-lite/assets/3916816/5c6e6f73-aedd-4256-8776-5ef57a728d5e) 46 | 47 | BPNet is a convolutional neural network that has been used to map nucleotide sequences to experimental readouts, e.g. ChIP-seq, ChIP-nexus, and ChIP-exo, and identify the driving motifs underlying these assays. Although these models achieve high predictive accuracy, their main purpose is to be interpreted using feature attribution methods to inspect the cis-regulatory code underlying the readouts being modeled. Specifically, when paired with a method like DeepLIFT/SHAP, they can be used to explain the driving motifs and syntax of those motifs underlying each signal peak in a readout. When looking across all peaks these attributions can be clustered using an algorithm like TF-MoDISco to identify repeated patterns. Finally, one can construct a variety of perturbations to reference sequence to identify variant effect or marginalize out background. 48 | 49 | ### BPNet Command Line Tools 50 | 51 | bpnet-lite comes with a command-line tool, `bpnet`, that supports the steps necessary for training and using BPNet models. Except for extracting GC-matched negatives, each command requires a JSON that contains the parameters, with examples of each in the `example_jsons` folder. See the README in that folder for exact parameters for each JSON. 52 | 53 | ``` 54 | bpnet negatives -i .bed -f .fa -b .bw -o matched_loci.bed -l 0.02 -w 2114 -v 55 | bpnet fit -p bpnet_fit_example.json 56 | bpnet predict -p bpnet_predict_example.json 57 | bpnet attribute -p bpnet_attribute_example.json 58 | bpnet marginalize -p bpnet_marginalize_example.json 59 | ``` 60 | 61 | Alternatively, one can use the `pipeline` command, whose purpose is to go all the way from the .BAM and .bed files to all the results of using BPNet without needing any hand-holding. This command handles the mapping of .BAM and .tsv files to bigWigs, the identification of GC-matched negatives, the training of the model, making predictions, calculating attributions, running [TF-MoDISco](https://github.com/jmschrei/tfmodisco-lite) and generating a report on the found motifs, and performing marginalizations. For each step (except the TF-MoDISco one), a JSON is generated to serve as a record for what the precise input to each step was, and to allow easy editing in case something has gone slightly wrong. 62 | 63 | ``` 64 | bpnet pipeline -p bpnet_pipeline_example.json 65 | ``` 66 | 67 | For a complete description of the pipeline JSON, see the `example_jsons` folded. However, it is extremely flexible. For example, a different set of sequences or loci can be used in each step, allowing one to train a model genome-wide and then apply it to a set of synthetic constructs in a separate FASTA. Alternatively, one can train the model using one reference genome and apply it to another reference genome. 68 | 69 | If you want to run the entire pipeline but find these JSONs daunting you can use the following command that takes the filepaths to the input data and fills in the default pipeline JSON for you. 70 | 71 | ``` 72 | bpnet pipeline-json -i <.sam, .bam, .tsv, .tsv.gz, or .bw> -c -s -l -n -o -m 73 | ``` 74 | 75 | This command can optionally take in `-f` if the data are fragments and `-u` if the data are unstranded, and `-i` and `-c` can be repeated. None of the above fields are required if your situation requires something more complicated, but if all fields are provided the `bpnet pipeline` command can be run directly on the JSON without modification. You should check the JSON to make sure that everything is correct, though, e.g., the right number of filters and layers in the model. 76 | 77 | ## ChromBPNet 78 | 79 | ![image](https://github.com/jmschrei/bpnet-lite/assets/3916816/e6f9bbdf-f107-4b3e-8b97-dc552af2239c) 80 | 81 | ChromBPNet extends the original modeling framework to DNase-seq and ATAC-seq experiments. A separate framework is necessary because the cutting enzymes used in these experiments, particularly the hyperactive Tn5 enzyme used in ATAC-seq experiments, have soft sequences preferences that can distort the observed readouts. Hence, it becomes necessary to train a small BPNet model to explicitly capture this soft sequence (the "bias model") bias before subsequently training a second BPNet model jointly with the frozen bias model to capture the true drivers of accessibility (the "accessibiity model"). Together, these models and the manner in which their predictions are combined are referred to as ChromBPNet. 82 | 83 | Generally, one can perform the same analyses using ChromBPNet as one can using BPNet. However, an important note is that the full ChromBPNet model faithfully represents the experimental readout -- bias and all -- and so for more inspection tasks, e.g. variant effect prediction and interpretation, one should use only the accessibility model. Because the accessibiity model itself is conceptually, and also literally implemented as, a BPNet model, one can run the same procedure and use the BPNet command-line tool using it. 84 | 85 | ### 86 | 87 | bpnet-lite comes with a second command-line tool, `chrombpnet`, that supports the steps necessary for training and using ChromBPNet models. These commands are used exactly the same way as the `bpnet` command-line tool with only minor changes to the parameters in the JSON. Note that the `predict`, `attribute` and `marginalize` commands will internally run their `bpnet` counterparts, but are still provided for convenience. 88 | 89 | ``` 90 | chrombpnet fit -p chrombpnet_fit_example.json 91 | chrombpnet predict -p chrombpnet_predict_example.json 92 | chrombpnet attribute -p chrombpnet_attribute_example.json 93 | chrombpnet marginalize -p chrombpnet_marginalize_example.json 94 | ``` 95 | 96 | Similarly to `bpnet`, one can run the entire pipeline of commands specified above in addition to also running TF-MoDISco and generating a report on the found motifs. Unlike `bpnet`, this command will run each of those steps for (1) the full ChromBPNet model, (2) the accessibility model alone, and (3) the bias model. 97 | 98 | ``` 99 | chrombpnet pipeline -p chrombpnet_pipeline_example.json 100 | ``` 101 | 102 | ## Python API 103 | 104 | If you'd rather train and use BPNet/ChromBPNet models programmatically, you can use the Python API. The command-line tool is made up of wrappers around these methods and functions, so please take a look if you'd like additional documentation on how to get started. 105 | 106 | The first step is loading data. Much like with the command-line tool, if you're using the built-in data loader then you need to specify where the FASTA containing sequences, a BED file containing loci and bigwig files to train on are. The signals need to be provided in a list and the index of each bigwig in the list will correspond to a model output. Optionally, you can also provide control bigwigs. See the BPNet paper for how these control bigwigs get used during training. 107 | 108 | ```python 109 | import torch 110 | 111 | from tangermeme.io import extract_loci 112 | from bpnetlite.io import PeakGenerator 113 | from bpnetlite import BPNet 114 | 115 | peaks = 'test/CTCF.peaks.bed' # A set of loci to train on. 116 | seqs = '../../oak/common/hg38/hg38.fa' # A set of sequences to train on 117 | signals = ['test/CTCF.plus.bw', 'test/CTCF.minus.bw'] # A set of bigwigs 118 | controls = ['test/CTCF.plus.ctl.bw', 'test/CTCF.minus.ctl.bw'] # A set of bigwigs 119 | ``` 120 | 121 | After specifying filepaths for each of these, you can create the data generator. If you have a set of chromosomes you'd like to use for training, you can pass those in as well. They must match exactly with the names of chromsomes given in the BED file. 122 | 123 | ```python 124 | training_chroms = ['chr{}'.format(i) for i in range(1, 17)] 125 | 126 | training_data = PeakGenerator(peaks, seqs, signals, controls, chroms=training_chroms) 127 | ``` 128 | 129 | The `PeakGenerator` function is a wrapper around several functions that extract data, pass them into a generator that applies shifts and shuffling, and pass that generator into a PyTorch data loader object for use during training. The end result is an object that can be directly iterated over while training a bpnet-lite model. 130 | 131 | Although wrapping all that functionality is good for the training set, the validation set should remain constant during training. Hence, one should only use the `extract_loci` function that is the first step when handling the training data. 132 | 133 | ```python 134 | valid_chroms = ['chr{}'.format(i) for i in range(18, 23)] 135 | 136 | X_valid, y_valid, X_ctl_valid = extract_peaks(peaks, seqs, signals, controls, chroms=valid_chroms, max_jitter=0) 137 | ``` 138 | Note that this function can be used without control tracks and, in that case, will only return two arguments. Further, it can used with only a FASTA and will only return one argument in that case: the extracted sequences. 139 | 140 | Now, we can define the model. If you want to change the architecture, check out the documentation. 141 | 142 | ```python 143 | model = BPNet(n_outputs=2, n_control_tracks=2, trimming=(2114 - 1000) // 2).cuda() 144 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 145 | ``` 146 | 147 | And, finally, we can call the `fit_generator` method to train the model. This function is largely just a training loop that trains the profile head using the multinomial log-likelihood loss and the count head using the mean-squared error loss, but a benefit of this built-in method is that it outputs a tsv of the training statistics that you can redirect to a log file. 148 | 149 | ```python 150 | model.fit(training_data, optimizer, X_valid=X_valid, 151 | X_ctl_valid=X_ctl_valid, y_valid=y_valid) 152 | ``` 153 | 154 | Because `model` is a PyTorch object, it can be trained using a custom training loop in the same way any base PyTorch model can be trained if you'd prefer to do that. Likewise, if you'd prefer to use a custom data generator you can write your own and pass that into the `fit` function. 155 | -------------------------------------------------------------------------------- /bpnet: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # BPNet command-line tool 3 | # Author: Jacob Schreiber 4 | 5 | import os 6 | os.environ['TORCH_CUDNN_V8_API_ENABLED'] = '1' 7 | 8 | import sys 9 | import numpy 10 | import torch 11 | import argparse 12 | import subprocess 13 | 14 | 15 | from bpnetlite.io import PeakGenerator 16 | 17 | from bpnetlite.bpnet import BPNet 18 | from bpnetlite.bpnet import CountWrapper 19 | from bpnetlite.bpnet import ProfileWrapper 20 | from bpnetlite.bpnet import ControlWrapper 21 | 22 | from bpnetlite.chrombpnet import ChromBPNet 23 | 24 | from bpnetlite.attribute import deep_lift_shap 25 | from bpnetlite.marginalize import marginalization_report 26 | 27 | from tangermeme.io import extract_loci 28 | from tangermeme.match import extract_matching_loci 29 | from tangermeme.predict import predict 30 | 31 | import json 32 | 33 | torch.backends.cudnn.benchmark = True 34 | 35 | 36 | desc = """BPNet is an neural network primarily composed of dilated residual 37 | convolution layers for modeling the associations between biological 38 | sequences and biochemical readouts. This tool will take in a fasta 39 | file for the sequence, a bed file for signal peak locations, and bigWig 40 | files for the signal to predict and the control signal, and train a 41 | BPNet model for you.""" 42 | 43 | _help = """Must be either 'negatives', 'fit', 'predict', 'attribute', 44 | 'marginalize', or 'pipeline'.""" 45 | 46 | 47 | # Read in the arguments 48 | parser = argparse.ArgumentParser(description=desc) 49 | subparsers = parser.add_subparsers(help=_help, required=True, dest='cmd') 50 | 51 | negatives_parser = subparsers.add_parser("negatives", 52 | help="Sample GC-matched negatives.") 53 | negatives_parser.add_argument("-i", "--peaks", required=True, 54 | help="Peak bed file.") 55 | negatives_parser.add_argument("-f", "--fasta", help="Genome FASTA file.") 56 | negatives_parser.add_argument("-b", "--bigwig", help="Optional signal bigwig.") 57 | negatives_parser.add_argument("-o", "--output", required=True, 58 | help="Output bed file.") 59 | negatives_parser.add_argument("-l", "--bin_width", type=float, default=0.02, 60 | help="GC bin width to match.") 61 | negatives_parser.add_argument("-n", "--max_n_perc", type=float, default=0.1, 62 | help="Maximum percentage of Ns allowed in each locus.") 63 | negatives_parser.add_argument("-a", "--beta", type=float, default=0.5, 64 | help="Multiplier on the minimum counts in peaks.") 65 | negatives_parser.add_argument("-w", "--in_window", type=int, default=2114, 66 | help="Width for calculating GC content.") 67 | negatives_parser.add_argument("-x", "--out_window", type=int, default=1000, 68 | help="Non-overlapping stride to use for loci.") 69 | negatives_parser.add_argument("-v", "--verbose", default=False, 70 | action='store_true') 71 | 72 | fit_parser = subparsers.add_parser("fit", help="Fit a BPNet model.") 73 | fit_parser.add_argument("-p", "--parameters", type=str, required=True, 74 | help="A JSON file containing the parameters for fitting the model.") 75 | 76 | predict_parser = subparsers.add_parser("predict", 77 | help="Make predictions using a trained BPNet model.") 78 | predict_parser.add_argument("-p", "--parameters", type=str, required=True, 79 | help="A JSON file containing the parameters for making predictions.") 80 | 81 | attribute_parser = subparsers.add_parser("attribute", 82 | help="Calculate attributions using a trained BPNet model.") 83 | attribute_parser.add_argument("-p", "--parameters", type=str, required=True, 84 | help="A JSON file containing the parameters for calculating attributions.") 85 | 86 | marginalize_parser = subparsers.add_parser("marginalize", 87 | help="Run marginalizations given motifs.") 88 | marginalize_parser.add_argument("-p", "--parameters", type=str, required=True, 89 | help="A JSON file containing the parameters for calculating attributions.") 90 | 91 | pipeline_parser = subparsers.add_parser("pipeline", 92 | help="Run each step on the given files.") 93 | pipeline_parser.add_argument("-p", "--parameters", type=str, required=True, 94 | help="A JSON file containing the parameters used for each step.") 95 | 96 | 97 | ### 98 | # Default Parameters 99 | ### 100 | 101 | training_chroms = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 102 | 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 103 | 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] 104 | 105 | validation_chroms = ['chr8', 'chr10'] 106 | 107 | default_fit_parameters = { 108 | 'n_filters': 64, 109 | 'n_layers': 8, 110 | 'profile_output_bias': True, 111 | 'count_output_bias': True, 112 | 'name': None, 113 | 'batch_size': 64, 114 | 'in_window': 2114, 115 | 'out_window': 1000, 116 | 'max_jitter': 128, 117 | 'reverse_complement': True, 118 | 'max_epochs': 50, 119 | 'validation_iter': 100, 120 | 'lr': 0.001, 121 | 'alpha': 1, 122 | 'dtype': 'float32', 123 | 'early_stopping': None, 124 | 'verbose': False, 125 | 126 | 'min_counts': 0, 127 | 'max_counts': 99999999, 128 | 129 | 'training_chroms': training_chroms, 130 | 'validation_chroms': validation_chroms, 131 | 'sequences': None, 132 | 'loci': None, 133 | 'signals': None, 134 | 'controls': None, 135 | 'random_state': None 136 | } 137 | 138 | default_predict_parameters = { 139 | 'batch_size': 64, 140 | 'in_window': 2114, 141 | 'out_window': 1000, 142 | 'verbose': False, 143 | 'chroms': training_chroms, 144 | 'sequences': None, 145 | 'loci': None, 146 | 'controls': None, 147 | 'model': None, 148 | 'profile_filename': 'y_profile.npz', 149 | 'counts_filename': 'y_counts.npz' 150 | } 151 | 152 | default_attribute_parameters = { 153 | 'batch_size': 64, 154 | 'in_window': 2114, 155 | 'out_window': 1000, 156 | 'verbose': False, 157 | 'chroms': training_chroms, 158 | 'sequences': None, 159 | 'loci': None, 160 | 'model': None, 161 | 'output': 'counts', 162 | 'ohe_filename': 'ohe.npz', 163 | 'attr_filename': 'attr.npz', 164 | 'n_shuffles':20, 165 | 'random_state':0, 166 | 'warning_threshold':1e-4 167 | } 168 | 169 | default_marginalize_parameters = { 170 | 'batch_size': 64, 171 | 'in_window': 2114, 172 | 'out_window': 1000, 173 | 'verbose': False, 174 | 'chroms': training_chroms, 175 | 'sequences': None, 176 | 'motifs': None, 177 | 'loci': None, 178 | 'n_loci': None, 179 | 'shuffle': False, 180 | 'model': None, 181 | 'output_filename':'marginalize/', 182 | 'random_state':0, 183 | 'minimal': True 184 | } 185 | 186 | default_pipeline_parameters = { 187 | # Model architecture parameters 188 | 'n_filters': 64, 189 | 'n_layers': 8, 190 | 'profile_output_bias': True, 191 | 'count_output_bias': True, 192 | 'in_window': 2114, 193 | 'out_window': 1000, 194 | 'name': None, 195 | 'model': None, 196 | 'early_stopping': None, 197 | 'verbose': False, 198 | 199 | # Data parameters 200 | 'batch_size': 64, 201 | 'max_jitter': 128, 202 | 'reverse_complement': True, 203 | 'max_epochs': 50, 204 | 'validation_iter': 100, 205 | 'lr': 0.001, 206 | 'alpha': 1, 207 | 'verbose': False, 208 | 'min_counts': 0, 209 | 'max_counts': 99999999, 210 | 211 | 'sequences': None, 212 | 'loci': None, 213 | 'signals': None, 214 | 'controls': None, 215 | 'find_negatives': False, 216 | 217 | # Fit parameters 218 | 'fit_parameters': { 219 | 'batch_size': 64, 220 | 'training_chroms': training_chroms, 221 | 'sequences': None, 222 | 'loci': None, 223 | 'signals': None, 224 | 'controls': None, 225 | 'verbose': None, 226 | 'random_state': None, 227 | }, 228 | 229 | # Predict parameters 230 | 'predict_parameters': { 231 | 'batch_size': 64, 232 | 'chroms': validation_chroms, 233 | 'profile_filename': None, 234 | 'counts_filename': None, 235 | 'sequences': None, 236 | 'loci': None, 237 | 'signals': None, 238 | 'controls': None, 239 | 'verbose': None, 240 | }, 241 | 242 | 243 | # Attribution parameters 244 | 'attribute_parameters': { 245 | 'batch_size': 64, 246 | 'chroms': validation_chroms, 247 | 'output': 'counts', 248 | 'loci': None, 249 | 'ohe_filename': None, 250 | 'attr_filename': None, 251 | 'n_shuffles': None, 252 | 'warning_threshold':1e-4, 253 | 'random_state': None, 254 | 'verbose': None 255 | 256 | }, 257 | 258 | # Modisco parameters 259 | 'modisco_motifs_parameters': { 260 | 'n_seqlets': 100000, 261 | 'output_filename': None, 262 | 'verbose': None 263 | }, 264 | 265 | # Modisco report parameters 266 | 'modisco_report_parameters': { 267 | 'motifs': None, 268 | 'output_folder': None, 269 | 'verbose': None 270 | }, 271 | 272 | # Marginalization parameters 273 | 'marginalize_parameters': { 274 | 'loci': None, 275 | 'n_loci': 100, 276 | 'batch_size': 64, 277 | 'shuffle': False, 278 | 'random_state': None, 279 | 'output_folder': None, 280 | 'motifs': None, 281 | 'minimal': True, 282 | 'verbose': None 283 | } 284 | } 285 | 286 | 287 | ### 288 | # Commands 289 | ### 290 | 291 | 292 | def _extract_set(parameters, defaults, name): 293 | subparameters = { 294 | key: parameters.get(key, None) for key in defaults if key in parameters 295 | } 296 | 297 | for parameter, value in parameters[name].items(): 298 | if value is not None: 299 | subparameters[parameter] = value 300 | 301 | return subparameters 302 | 303 | def _check_set(parameters, parameter, value): 304 | if parameters.get(parameter, None) == None: 305 | parameters[parameter] = value 306 | 307 | 308 | def merge_parameters(parameters, default_parameters): 309 | """Merge the provided parameters with the default parameters. 310 | 311 | 312 | Parameters 313 | ---------- 314 | parameters: str 315 | Name of the JSON folder with the provided parameters 316 | 317 | default_parameters: dict 318 | The default parameters for the operation. 319 | 320 | 321 | Returns 322 | ------- 323 | params: dict 324 | The merged set of parameters. 325 | """ 326 | 327 | with open(parameters, "r") as infile: 328 | parameters = json.load(infile) 329 | 330 | unset_parameters = ("controls", "warning_threshold", "early_stopping") 331 | for parameter, value in default_parameters.items(): 332 | if parameter not in parameters: 333 | if value is None and parameter not in unset_parameters: 334 | raise ValueError("Must provide value for '{}'".format(parameter)) 335 | 336 | parameters[parameter] = value 337 | 338 | return parameters 339 | 340 | 341 | # Pull the arguments 342 | args = parser.parse_args() 343 | 344 | 345 | ########## 346 | # NEGATIVES 347 | ########## 348 | 349 | if args.cmd == 'negatives': 350 | # Extract regions that match the GC content of the peaks 351 | matched_loci = extract_matching_loci( 352 | loci=args.peaks, 353 | fasta=args.fasta, 354 | gc_bin_width=args.bin_width, 355 | max_n_perc=args.max_n_perc, 356 | bigwig=args.bigwig, 357 | signal_beta=args.beta, 358 | in_window=args.in_window, 359 | out_window=args.out_window, 360 | chroms=None, 361 | verbose=args.verbose 362 | ) 363 | 364 | matched_loci.to_csv(args.output, header=False, sep='\t', index=False) 365 | 366 | 367 | ########## 368 | # FIT 369 | ########## 370 | 371 | if args.cmd == "fit": 372 | parameters = merge_parameters(args.parameters, default_fit_parameters) 373 | 374 | ### 375 | 376 | training_data = PeakGenerator( 377 | loci=parameters['loci'], 378 | sequences=parameters['sequences'], 379 | signals=parameters['signals'], 380 | controls=parameters['controls'], 381 | chroms=parameters['training_chroms'], 382 | in_window=parameters['in_window'], 383 | out_window=parameters['out_window'], 384 | max_jitter=parameters['max_jitter'], 385 | reverse_complement=parameters['reverse_complement'], 386 | min_counts=parameters['min_counts'], 387 | max_counts=parameters['max_counts'], 388 | random_state=parameters['random_state'], 389 | batch_size=parameters['batch_size'], 390 | verbose=parameters['verbose'] 391 | ) 392 | 393 | valid_data = extract_loci( 394 | sequences=parameters['sequences'], 395 | signals=parameters['signals'], 396 | in_signals=parameters['controls'], 397 | loci=parameters['loci'], 398 | chroms=parameters['validation_chroms'], 399 | in_window=parameters['in_window'], 400 | out_window=parameters['out_window'], 401 | max_jitter=0, 402 | ignore=list('QWERYUIOPSDFHJKLZXVBNM'), 403 | verbose=parameters['verbose'] 404 | ) 405 | 406 | if parameters['controls'] is not None: 407 | valid_sequences, valid_signals, valid_controls = valid_data 408 | n_control_tracks = len(parameters['controls']) 409 | else: 410 | valid_sequences, valid_signals = valid_data 411 | valid_controls = None 412 | n_control_tracks = 0 413 | 414 | trimming = (parameters['in_window'] - parameters['out_window']) // 2 415 | 416 | model = BPNet(n_filters=parameters['n_filters'], 417 | n_layers=parameters['n_layers'], 418 | n_outputs=len(parameters['signals']), 419 | n_control_tracks=n_control_tracks, 420 | profile_output_bias=parameters['profile_output_bias'], 421 | count_output_bias=parameters['count_output_bias'], 422 | alpha=parameters['alpha'], 423 | trimming=trimming, 424 | name=parameters['name'], 425 | verbose=parameters['verbose']).cuda() 426 | 427 | optimizer = torch.optim.AdamW(model.parameters(), lr=parameters['lr']) 428 | 429 | if parameters['verbose']: 430 | print("Training Set Size: ", training_data.dataset.sequences.shape[0]) 431 | print("Validation Set Size: ", valid_sequences.shape[0]) 432 | 433 | model.fit(training_data, optimizer, X_valid=valid_sequences, 434 | X_ctl_valid=valid_controls, y_valid=valid_signals, 435 | max_epochs=parameters['max_epochs'], 436 | validation_iter=parameters['validation_iter'], 437 | batch_size=parameters['batch_size'], 438 | early_stopping=parameters['early_stopping'], 439 | dtype=parameters['dtype']) 440 | 441 | 442 | ########## 443 | # PREDICT 444 | ########## 445 | 446 | elif args.cmd == 'predict': 447 | parameters = merge_parameters(args.parameters, default_predict_parameters) 448 | 449 | ### 450 | 451 | model = torch.load(parameters['model'], weights_only=False).cuda() 452 | 453 | examples = extract_loci( 454 | sequences=parameters['sequences'], 455 | in_signals=parameters['controls'], 456 | loci=parameters['loci'], 457 | chroms=parameters['chroms'], 458 | in_window=parameters['in_window'], 459 | out_window=parameters['out_window'], 460 | max_jitter=0, 461 | ignore=list('QWERYUIOPSDFHJKLZXVBNM'), 462 | verbose=parameters['verbose'] 463 | ) 464 | 465 | if parameters['controls'] == None: 466 | X = examples 467 | if model.n_control_tracks > 0: 468 | X_ctl = torch.zeros(X.shape[0], model.n_control_tracks, X.shape[-1]) 469 | else: 470 | X_ctl = None 471 | else: 472 | X, X_ctl = examples 473 | 474 | if X_ctl is not None: 475 | X_ctl = (X_ctl,) 476 | 477 | y_profiles, y_counts = predict(model, X, args=X_ctl, 478 | batch_size=parameters['batch_size'], device='cuda', 479 | verbose=parameters['verbose']) 480 | 481 | numpy.savez_compressed(parameters['profile_filename'], y_profiles) 482 | numpy.savez_compressed(parameters['counts_filename'], y_counts) 483 | 484 | 485 | ########## 486 | # ATTRIBUTE 487 | ########## 488 | 489 | elif args.cmd == 'attribute': 490 | parameters = merge_parameters(args.parameters, default_attribute_parameters) 491 | 492 | ### 493 | 494 | model = torch.load(parameters['model'], weights_only=False).cuda() 495 | 496 | dtype = torch.float32 497 | if parameters['output'] == 'profile' or isinstance(model, ChromBPNet): 498 | dtype = torch.float64 499 | 500 | X = extract_loci( 501 | sequences=parameters['sequences'], 502 | loci=parameters['loci'], 503 | chroms=parameters['chroms'], 504 | max_jitter=0, 505 | ignore=list('QWERYUIOPSDFHJKLZXVBNM'), 506 | verbose=parameters['verbose'] 507 | ) 508 | 509 | X = X[X.sum(dim=(1, 2)) == X.shape[-1]] 510 | 511 | model = ControlWrapper(model) 512 | if parameters['output'] == 'counts': 513 | wrapper = CountWrapper(model) 514 | elif parameters['output'] == 'profile': 515 | wrapper = ProfileWrapper(model) 516 | else: 517 | raise ValueError("output must be either `counts` or `profile`.") 518 | 519 | X_attr = deep_lift_shap(wrapper.type(dtype), X.type(dtype), 520 | hypothetical=True, 521 | n_shuffles=parameters['n_shuffles'], 522 | batch_size=parameters['batch_size'], 523 | random_state=parameters['random_state'], 524 | verbose=parameters['verbose'], 525 | warning_threshold=parameters['warning_threshold']) 526 | 527 | numpy.savez_compressed(parameters['ohe_filename'], X) 528 | numpy.savez_compressed(parameters['attr_filename'], X_attr) 529 | 530 | 531 | ########## 532 | # MARGINALIZE 533 | ########## 534 | 535 | elif args.cmd == 'marginalize': 536 | parameters = merge_parameters(args.parameters, 537 | default_marginalize_parameters) 538 | 539 | ### 540 | 541 | model = torch.load(parameters['model'], weights_only=False).cuda() 542 | model = ControlWrapper(model) 543 | 544 | X = extract_loci( 545 | sequences=parameters['sequences'], 546 | loci=parameters['loci'], 547 | chroms=parameters['chroms'], 548 | max_jitter=0, 549 | ignore=list('QWERYUIOPSDFHJKLZXVBNM'), 550 | n_loci=parameters['n_loci'], 551 | verbose=parameters['verbose'] 552 | ).float() 553 | 554 | if parameters['shuffle'] == True: 555 | idxs = numpy.arange(X.shape[0]) 556 | numpy.random.shuffle(idxs) 557 | X = X[idxs] 558 | 559 | if parameters['n_loci'] is not None: 560 | X = X[:parameters['n_loci']] 561 | 562 | marginalization_report(model, parameters['motifs'], X, 563 | parameters['output_filename'], batch_size=parameters['batch_size'], 564 | minimal=parameters['minimal'], verbose=parameters['verbose']) 565 | 566 | 567 | ########## 568 | # PIPLEINE 569 | ########## 570 | 571 | elif args.cmd == 'pipeline': 572 | parameters = merge_parameters(args.parameters, default_pipeline_parameters) 573 | pname = parameters['name'] 574 | 575 | """ 576 | ### Step 0.0: Convert from SAM/BAMs to bigwigs if provided 577 | if parameters['signals'][0].endswith(".bam"): 578 | args = [ 579 | "bam2bw", 580 | "-s", parameters['chrom_sizes'], 581 | "-n", pname, 582 | "-v", parameters['verbose'] 583 | ] 584 | 585 | if parameters["stranded"] == False: 586 | args += ["-u"] 587 | 588 | args += parameters['signals'] 589 | 590 | subprocess.run(args, check=True) 591 | 592 | if parameters["stranded"] == False: 593 | parameters['signals'] = [pname + ".bam"] 594 | else: 595 | parameters['signals'] = [pname + ".+.bam", pname + ".-.bam"] 596 | 597 | if parameters['controls'] is not None: 598 | if parameters['controls'][0].endswith(".bam"): 599 | args = [ 600 | "bam2bw", 601 | "-s", parameters['chrom_sizes'], 602 | "-n", pname, 603 | "-v", parameters['verbose'] 604 | ] 605 | 606 | if parameters["stranded"] == False: 607 | args += ["-u"] 608 | 609 | args += parameters['controls'] 610 | 611 | subprocess.run(args, check=True) 612 | 613 | if parameters["stranded"] == False: 614 | parameters['controls'] = [pname + ".bam"] 615 | else: 616 | parameters['controls'] = [pname + ".+.bam", pname + ".-.bam"] 617 | 618 | ### Step 0.1: Identify GC-matched negative regions 619 | if parameters['find_negatives'] == True: 620 | args = [ 621 | "bpnet", "negatives", 622 | "-i", parameters["loci"][0], 623 | "-f", parameters["sequences"], 624 | "-o", pname + ".negatives.bed" 625 | ] 626 | 627 | if parameters['verbose']: 628 | args += ['-v'] 629 | 630 | parameters["loci"] += [pname + ".negatives.bed"] 631 | parameters["negatives"] = pname + ".negatives.bed" 632 | 633 | subprocess.run(args, check=True) 634 | """ 635 | 636 | 637 | ### Step 1: Fit a BPNet model to the provided data 638 | if parameters['verbose']: 639 | print("Step 1: Fitting a BPNet model") 640 | 641 | fit_parameters = _extract_set(parameters, default_fit_parameters, 642 | 'fit_parameters') 643 | 644 | if parameters.get('model', None) == None: 645 | name = pname + '.bpnet.fit.json' 646 | parameters['model'] = pname + '.torch' 647 | 648 | with open(name, 'w') as outfile: 649 | outfile.write(json.dumps(fit_parameters, sort_keys=True, indent=4)) 650 | 651 | subprocess.run(["bpnet", "fit", "-p", name], check=True) 652 | 653 | 654 | ### Step 2: Make predictions for the entire validation set 655 | if parameters['verbose']: 656 | print("\nStep 2: Making predictions") 657 | 658 | predict_parameters = _extract_set(parameters, 659 | default_predict_parameters, 'predict_parameters') 660 | _check_set(predict_parameters, 'profile_filename', pname+'.y_profiles.npz') 661 | _check_set(predict_parameters, 'counts_filename', pname+'.y_counts.npz') 662 | 663 | name = '{}.bpnet.predict.json'.format(parameters['name']) 664 | with open(name, 'w') as outfile: 665 | outfile.write(json.dumps(predict_parameters, sort_keys=True, indent=4)) 666 | 667 | subprocess.run(["bpnet", "predict", "-p", name], check=True) 668 | 669 | 670 | ### Step 3: Calculate attributions 671 | if parameters['verbose']: 672 | print("\nStep 3: Calculating attributions") 673 | 674 | 675 | attribute_parameters = _extract_set(parameters, 676 | default_attribute_parameters, 'attribute_parameters') 677 | _check_set(attribute_parameters, 'ohe_filename', pname+'.ohe.npz') 678 | _check_set(attribute_parameters, 'attr_filename', pname+'.attr.npz') 679 | 680 | name = '{}.bpnet.attribute.json'.format(parameters['name']) 681 | with open(name, 'w') as outfile: 682 | outfile.write(json.dumps(attribute_parameters, sort_keys=True, 683 | indent=4)) 684 | 685 | subprocess.run(["bpnet", "attribute", "-p", name], check=True) 686 | 687 | 688 | ### Step 4: Calculate tf-modisco motifs 689 | if parameters['verbose']: 690 | print("\nStep 4: TF-MoDISco motifs") 691 | 692 | modisco_parameters = parameters['modisco_motifs_parameters'] 693 | _check_set(modisco_parameters, "output_filename", 694 | pname+'_modisco_results.h5') 695 | 696 | cmd = "modisco motifs -s {} -a {} -n {} -o {}".format( 697 | attribute_parameters['ohe_filename'], 698 | attribute_parameters['attr_filename'], 699 | modisco_parameters['n_seqlets'], 700 | modisco_parameters['output_filename']) 701 | if modisco_parameters['verbose']: 702 | cmd += ' -v' 703 | 704 | subprocess.run(cmd.split(), check=True) 705 | 706 | 707 | ### Step 5: Generate the tf-modisco report 708 | modisco_name = "{}_modisco_results.h5".format(parameters['name']) 709 | 710 | report_parameters = parameters['modisco_report_parameters'] 711 | _check_set(report_parameters, "verbose", parameters["verbose"]) 712 | _check_set(report_parameters, "output_folder", pname+"_modisco/") 713 | 714 | if report_parameters['verbose']: 715 | print("\nStep 5: TF-MoDISco reports") 716 | 717 | subprocess.run(["modisco", "report", 718 | "-i", modisco_parameters['output_filename'], 719 | "-o", report_parameters['output_folder'], 720 | "-s", report_parameters['output_folder'], 721 | "-m", report_parameters['motifs'] 722 | ], check=True) 723 | 724 | 725 | ### Step 6: Marginalization experiments 726 | if parameters['verbose']: 727 | print("\nStep 6: Run marginalizations") 728 | 729 | marginalize_parameters = _extract_set(parameters, 730 | default_marginalize_parameters, "marginalize_parameters") 731 | _check_set(marginalize_parameters, "loci", parameters["negatives"]) 732 | _check_set(marginalize_parameters, 'output_filename', pname+"_marginalize/") 733 | 734 | name = '{}.bpnet.marginalize.json'.format(parameters['name']) 735 | 736 | with open(name, 'w') as outfile: 737 | outfile.write(json.dumps(marginalize_parameters, sort_keys=True, 738 | indent=4)) 739 | 740 | subprocess.run(["bpnet", "marginalize", "-p", name], check=True) 741 | -------------------------------------------------------------------------------- /bpnetlite/__init__.py: -------------------------------------------------------------------------------- 1 | # bpnet-lite 2 | # Author: Jacob Schreiber 3 | # Code adapted from Avanti Shrikumar and Ziga Avsec 4 | 5 | from .bpnet import BPNet 6 | from .chrombpnet import ChromBPNet 7 | 8 | __version__ = '0.8.1' 9 | -------------------------------------------------------------------------------- /bpnetlite/attribute.py: -------------------------------------------------------------------------------- 1 | # attribute.py 2 | # Author: Jacob Schreiber 3 | 4 | 5 | from bpnetlite.bpnet import _ProfileLogitScaling 6 | from bpnetlite.chrombpnet import _Log, _Exp 7 | 8 | from tangermeme.ersatz import dinucleotide_shuffle 9 | from tangermeme.deep_lift_shap import deep_lift_shap as t_deep_lift_shap 10 | from tangermeme.deep_lift_shap import _nonlinear 11 | 12 | 13 | def deep_lift_shap(model, X, args=None, target=0, batch_size=32, 14 | references=dinucleotide_shuffle, n_shuffles=20, return_references=False, 15 | hypothetical=False, warning_threshold=0.001, additional_nonlinear_ops=None, 16 | print_convergence_deltas=False, raw_outputs=False, device='cuda', 17 | random_state=None, verbose=False): 18 | """A wrapper that registers Chrom/BPNet's custom non-linearities. 19 | 20 | This function is just a wrapper for tangermeme's deep_lift_shap function 21 | except that it automatically registers the layers that are necessary for 22 | using BPNet models. Specifically, it registers a scaling that is necessary 23 | for calculating the profile attributions and also registers the logsumexp 24 | operation for counts when using the full ChromBPNet model. 25 | 26 | Other than automatically registering the non-linearities, this wrapper does 27 | not modify the tangermeme outputs or alter the inputs in any way. It is 28 | simply for convenience so you do not need to reach into bpnet-lite's 29 | internals each time you want to calculate attributions. 30 | 31 | 32 | Parameters 33 | ---------- 34 | model: torch.nn.Module 35 | A PyTorch model to use for making predictions. These models can take in 36 | any number of inputs and make any number of outputs. The additional 37 | inputs must be specified in the `args` parameter. 38 | 39 | X: torch.tensor, shape=(-1, len(alphabet), length) 40 | A set of one-hot encoded sequences to calculate attribution values 41 | for. 42 | 43 | args: tuple or None, optional 44 | An optional set of additional arguments to pass into the model. If 45 | provided, each element in the tuple or list is one input to the model 46 | and the element must be formatted to be the same batch size as `X`. If 47 | None, no additional arguments are passed into the forward function. 48 | Default is None. 49 | 50 | target: int, optional 51 | The output of the model to calculate gradients/attributions for. This 52 | will index the last dimension of the predictions. Default is 0. 53 | 54 | batch_size: int, optional 55 | The number of sequence-reference pairs to pass through DeepLiftShap at 56 | a time. Importantly, this is not the number of elements in `X` that 57 | are processed simultaneously (alongside ALL their references) but the 58 | total number of `X`-`reference` pairs that are processed. This means 59 | that if you are in a memory-limited setting where you cannot process 60 | all references for even a single sequence simultaneously that the 61 | work is broken down into doing only a few references at a time. Default 62 | is 32. 63 | 64 | references: func or torch.Tensor, optional 65 | If a function is passed in, this function is applied to each sequence 66 | with the provided random state and number of shuffles. This function 67 | should serve to transform a sequence into some form of signal-null 68 | background, such as by shuffling it. If a torch.Tensor is passed in, 69 | that tensor must have shape `(len(X), n_shuffles, *X.shape[1:])`, in 70 | that for each sequence a number of shuffles are provided. Default is 71 | the function `dinucleotide_shuffle`. 72 | 73 | n_shuffles: int, optional 74 | The number of shuffles to use if a function is given for `references`. 75 | If a torch.Tensor is provided, this number is ignored. Default is 20. 76 | 77 | return_references: bool, optional 78 | Whether to return the references that were generated during this 79 | process. Only use if `references` is not a torch.Tensor. Default is 80 | False. 81 | 82 | hypothetical: bool, optional 83 | Whether to return attributions for all possible characters at each 84 | position or only for the character that is actually at the sequence. 85 | Practically, whether to return the returned attributions from captum 86 | with the one-hot encoded sequence. Default is False. 87 | 88 | warning_threshold: float, optional 89 | A threshold on the convergence delta that will always raise a warning 90 | if the delta is larger than it. Normal deltas are in the range of 91 | 1e-6 to 1e-8. Note that convergence deltas are calculated on the 92 | gradients prior to the aggr_func being applied to them. Default 93 | is 0.001. 94 | 95 | additional_nonlinear_ops: dict or None, optional 96 | If additional nonlinear ops need to be added to the dictionary of 97 | operations that can be handled by DeepLIFT/SHAP, pass a dictionary here 98 | where the keys are class types and the values are the name of the 99 | function that handle that sort of class. Make sure that the signature 100 | matches those of `_nonlinear` and `_maxpool` above. This can also be 101 | used to overwrite the hard-coded operations by passing in a dictionary 102 | with overlapping key names. If None, do not add any additional 103 | operations. Default is None. 104 | 105 | print_convergence_deltas: bool, optional 106 | Whether to print the convergence deltas for each example when using 107 | DeepLiftShap. Default is False. 108 | 109 | raw_outputs: bool, optional 110 | Whether to return the raw outputs from the method -- in this case, 111 | the multipliers for each example-reference pair -- or the processed 112 | attribution values. Default is False. 113 | 114 | device: str or torch.device, optional 115 | The device to move the model and batches to when making predictions. If 116 | set to 'cuda' without a GPU, this function will crash and must be set 117 | to 'cpu'. Default is 'cuda'. 118 | 119 | random_state: int or None or numpy.random.RandomState, optional 120 | The random seed to use to ensure determinism. If None, the 121 | process is not deterministic. Default is None. 122 | 123 | verbose: bool, optional 124 | Whether to display a progress bar. Default is False. 125 | 126 | 127 | Returns 128 | ------- 129 | attributions: torch.tensor 130 | If `raw_outputs=False` (default), the attribution values with shape 131 | equal to `X`. If `raw_outputs=True`, the multipliers for each example- 132 | reference pair with shape equal to `(X.shape[0], n_shuffles, X.shape[1], 133 | X.shape[2])`. 134 | 135 | references: torch.tensor, optional 136 | The references used for each input sequence, with the shape 137 | (n_input_sequences, n_shuffles, 4, length). Only returned if 138 | `return_references = True`. 139 | """ 140 | 141 | return t_deep_lift_shap(model=model, X=X, args=args, target=target, 142 | batch_size=batch_size, references=references, n_shuffles=n_shuffles, 143 | return_references=return_references, hypothetical=hypothetical, 144 | warning_threshold=warning_threshold, 145 | additional_nonlinear_ops={ 146 | _ProfileLogitScaling: _nonlinear, 147 | _Log: _nonlinear, 148 | _Exp: _nonlinear 149 | }, 150 | print_convergence_deltas=print_convergence_deltas, 151 | raw_outputs=raw_outputs, device=device, random_state=random_state, 152 | verbose=verbose) 153 | -------------------------------------------------------------------------------- /bpnetlite/bpnet.py: -------------------------------------------------------------------------------- 1 | # bpnet.py 2 | # Author: Jacob Schreiber 3 | 4 | """ 5 | This module contains a reference implementation of BPNet that can be used 6 | or adapted for your own circumstances. The implementation takes in a 7 | stranded control track and makes predictions for stranded outputs. 8 | """ 9 | 10 | import h5py 11 | import time 12 | import numpy 13 | import torch 14 | 15 | from .losses import MNLLLoss 16 | from .losses import log1pMSELoss 17 | from .performance import pearson_corr 18 | from .performance import calculate_performance_measures 19 | from .logging import Logger 20 | 21 | from tqdm import tqdm 22 | 23 | from tangermeme.predict import predict 24 | 25 | torch.backends.cudnn.benchmark = True 26 | 27 | 28 | class ControlWrapper(torch.nn.Module): 29 | """This wrapper automatically creates a control track of all zeroes. 30 | 31 | This wrapper will check to see whether the model is expecting a control 32 | track (e.g., most BPNet-style models) and will create one with the expected 33 | shape. If no control track is expected then it will provide the normal 34 | output from the model. 35 | """ 36 | 37 | def __init__(self, model): 38 | super(ControlWrapper, self).__init__() 39 | self.model = model 40 | 41 | def forward(self, X, X_ctl=None): 42 | if X_ctl != None: 43 | return self.model(X, X_ctl) 44 | 45 | if self.model.n_control_tracks == 0: 46 | return self.model(X) 47 | 48 | X_ctl = torch.zeros(X.shape[0], self.model.n_control_tracks, 49 | X.shape[-1], dtype=X.dtype, device=X.device) 50 | return self.model(X, X_ctl) 51 | 52 | 53 | 54 | class _ProfileLogitScaling(torch.nn.Module): 55 | """This ugly class is necessary because of Captum. 56 | 57 | Captum internally registers classes as linear or non-linear. Because the 58 | profile wrapper performs some non-linear operations, those operations must 59 | be registered as such. However, the inputs to the wrapper are not the 60 | logits that are being modified in a non-linear manner but rather the 61 | original sequence that is subsequently run through the model. Hence, this 62 | object will contain all of the operations performed on the logits and 63 | can be registered. 64 | 65 | 66 | Parameters 67 | ---------- 68 | logits: torch.Tensor, shape=(-1, -1) 69 | The logits as they come out of a Chrom/BPNet model. 70 | """ 71 | 72 | def __init__(self): 73 | super(_ProfileLogitScaling, self).__init__() 74 | self.softmax = torch.nn.Softmax(dim=-1) 75 | 76 | def forward(self, logits): 77 | y_softmax = self.softmax(logits) 78 | return logits * y_softmax 79 | 80 | 81 | class ProfileWrapper(torch.nn.Module): 82 | """A wrapper class that returns transformed profiles. 83 | 84 | This class takes in a trained model and returns the weighted softmaxed 85 | outputs of the first dimension. Specifically, it takes the predicted 86 | "logits" and takes the dot product between them and the softmaxed versions 87 | of those logits. This is for convenience when using captum to calculate 88 | attribution scores. 89 | 90 | Parameters 91 | ---------- 92 | model: torch.nn.Module 93 | A torch model to be wrapped. 94 | """ 95 | 96 | def __init__(self, model): 97 | super(ProfileWrapper, self).__init__() 98 | self.model = model 99 | self.flatten = torch.nn.Flatten() 100 | self.scaling = _ProfileLogitScaling() 101 | 102 | def forward(self, X, X_ctl=None, **kwargs): 103 | logits = self.model(X, X_ctl, **kwargs)[0] 104 | logits = self.flatten(logits) 105 | logits = logits - torch.mean(logits, dim=-1, keepdims=True) 106 | return self.scaling(logits).sum(dim=-1, keepdims=True) 107 | 108 | 109 | class CountWrapper(torch.nn.Module): 110 | """A wrapper class that only returns the predicted counts. 111 | 112 | This class takes in a trained model and returns only the second output. 113 | For BPNet models, this means that it is only returning the count 114 | predictions. This is for convenience when using captum to calculate 115 | attribution scores. 116 | 117 | Parameters 118 | ---------- 119 | model: torch.nn.Module 120 | A torch model to be wrapped. 121 | """ 122 | 123 | def __init__(self, model): 124 | super(CountWrapper, self).__init__() 125 | self.model = model 126 | 127 | def forward(self, X, X_ctl=None, **kwargs): 128 | return self.model(X, X_ctl, **kwargs)[1] 129 | 130 | 131 | class BPNet(torch.nn.Module): 132 | """A basic BPNet model with stranded profile and total count prediction. 133 | 134 | This is a reference implementation for BPNet models. It exactly matches the 135 | architecture in the official ChromBPNet repository. It is very similar to 136 | the implementation in the official basepairmodels repository but differs in 137 | when the activation function is applied for the resifual layers. See the 138 | BasePairNet object below for an implementation that matches that repository. 139 | 140 | The model takes in one-hot encoded sequence, runs it through: 141 | 142 | (1) a single wide convolution operation 143 | 144 | THEN 145 | 146 | (2) a user-defined number of dilated residual convolutions 147 | 148 | THEN 149 | 150 | (3a) profile predictions done using a very wide convolution layer 151 | that also takes in stranded control tracks 152 | 153 | AND 154 | 155 | (3b) total count prediction done using an average pooling on the output 156 | from 2 followed by concatenation with the log1p of the sum of the 157 | stranded control tracks and then run through a dense layer. 158 | 159 | This implementation differs from the original BPNet implementation in 160 | two ways: 161 | 162 | (1) The model concatenates stranded control tracks for profile 163 | prediction as opposed to adding the two strands together and also then 164 | smoothing that track 165 | 166 | (2) The control input for the count prediction task is the log1p of 167 | the strand-wise sum of the control tracks, as opposed to the raw 168 | counts themselves. 169 | 170 | (3) A single log softmax is applied across both strands such that 171 | the logsumexp of both strands together is 0. Put another way, the 172 | two strands are concatenated together, a log softmax is applied, 173 | and the MNLL loss is calculated on the concatenation. 174 | 175 | (4) The count prediction task is predicting the total counts across 176 | both strands. The counts are then distributed across strands according 177 | to the single log softmax from 3. 178 | 179 | 180 | Parameters 181 | ---------- 182 | n_filters: int, optional 183 | The number of filters to use per convolution. Default is 64. 184 | 185 | n_layers: int, optional 186 | The number of dilated residual layers to include in the model. 187 | Default is 8. 188 | 189 | n_outputs: int, optional 190 | The number of profile outputs from the model. Generally either 1 or 2 191 | depending on if the data is unstranded or stranded. Default is 2. 192 | 193 | n_control_tracks: int, optional 194 | The number of control tracks to feed into the model. When predicting 195 | TFs, this is usually 2. When predicting accessibility, this is usualy 196 | 0. When 0, this input is removed from the model. Default is 2. 197 | 198 | alpha: float, optional 199 | The weight to put on the count loss. 200 | 201 | profile_output_bias: bool, optional 202 | Whether to include a bias term in the final profile convolution. 203 | Removing this term can help with attribution stability and will usually 204 | not affect performance. Default is True. 205 | 206 | count_output_bias: bool, optional 207 | Whether to include a bias term in the linear layer used to predict 208 | counts. Removing this term can help with attribution stability but 209 | may affect performance. Default is True. 210 | 211 | name: str or None, optional 212 | The name to save the model to during training. 213 | 214 | trimming: int or None, optional 215 | The amount to trim from both sides of the input window to get the 216 | output window. This value is removed from both sides, so the total 217 | number of positions removed is 2*trimming. 218 | 219 | verbose: bool, optional 220 | Whether to display statistics during training. Setting this to False 221 | will still save the file at the end, but does not print anything to 222 | screen during training. Default is True. 223 | """ 224 | 225 | def __init__(self, n_filters=64, n_layers=8, n_outputs=2, 226 | n_control_tracks=2, alpha=1, profile_output_bias=True, 227 | count_output_bias=True, name=None, trimming=None, verbose=True): 228 | super(BPNet, self).__init__() 229 | self.n_filters = n_filters 230 | self.n_layers = n_layers 231 | self.n_outputs = n_outputs 232 | self.n_control_tracks = n_control_tracks 233 | 234 | self.alpha = alpha 235 | self.name = name or "bpnet.{}.{}".format(n_filters, n_layers) 236 | self.trimming = trimming or 2 ** n_layers 237 | 238 | self.iconv = torch.nn.Conv1d(4, n_filters, kernel_size=21, padding=10) 239 | self.irelu = torch.nn.ReLU() 240 | 241 | self.rconvs = torch.nn.ModuleList([ 242 | torch.nn.Conv1d(n_filters, n_filters, kernel_size=3, padding=2**i, 243 | dilation=2**i) for i in range(1, self.n_layers+1) 244 | ]) 245 | self.rrelus = torch.nn.ModuleList([ 246 | torch.nn.ReLU() for i in range(1, self.n_layers+1) 247 | ]) 248 | 249 | self.fconv = torch.nn.Conv1d(n_filters+n_control_tracks, n_outputs, 250 | kernel_size=75, padding=37, bias=profile_output_bias) 251 | 252 | n_count_control = 1 if n_control_tracks > 0 else 0 253 | self.linear = torch.nn.Linear(n_filters+n_count_control, 1, 254 | bias=count_output_bias) 255 | 256 | self.logger = Logger(["Epoch", "Iteration", "Training Time", 257 | "Validation Time", "Training MNLL", "Training Count MSE", 258 | "Validation MNLL", "Validation Profile Pearson", 259 | "Validation Count Pearson", "Validation Count MSE", "Saved?"], 260 | verbose=verbose) 261 | 262 | 263 | def forward(self, X, X_ctl=None): 264 | """A forward pass of the model. 265 | 266 | This method takes in a nucleotide sequence X, a corresponding 267 | per-position value from a control track, and a per-locus value 268 | from the control track and makes predictions for the profile 269 | and for the counts. This per-locus value is usually the 270 | log(sum(X_ctl_profile)+1) when the control is an experimental 271 | read track but can also be the output from another model. 272 | 273 | Parameters 274 | ---------- 275 | X: torch.tensor, shape=(batch_size, 4, length) 276 | The one-hot encoded batch of sequences. 277 | 278 | X_ctl: torch.tensor or None, shape=(batch_size, n_strands, length) 279 | A value representing the signal of the control at each position in 280 | the sequence. If no controls, pass in None. Default is None. 281 | 282 | Returns 283 | ------- 284 | y_profile: torch.tensor, shape=(batch_size, n_strands, out_length) 285 | The output predictions for each strand trimmed to the output 286 | length. 287 | """ 288 | 289 | start, end = self.trimming, X.shape[2] - self.trimming 290 | 291 | X = self.irelu(self.iconv(X)) 292 | for i in range(self.n_layers): 293 | X_conv = self.rrelus[i](self.rconvs[i](X)) 294 | X = torch.add(X, X_conv) 295 | 296 | if X_ctl is None: 297 | X_w_ctl = X 298 | else: 299 | X_w_ctl = torch.cat([X, X_ctl], dim=1) 300 | 301 | y_profile = self.fconv(X_w_ctl)[:, :, start:end] 302 | 303 | # counts prediction 304 | X = torch.mean(X[:, :, start-37:end+37], dim=2) 305 | if X_ctl is not None: 306 | X_ctl = torch.sum(X_ctl[:, :, start-37:end+37], dim=(1, 2)) 307 | X_ctl = X_ctl.unsqueeze(-1) 308 | X = torch.cat([X, torch.log(X_ctl+1)], dim=-1) 309 | 310 | y_counts = self.linear(X).reshape(X.shape[0], 1) 311 | return y_profile, y_counts 312 | 313 | 314 | def fit(self, training_data, optimizer, X_valid=None, X_ctl_valid=None, 315 | y_valid=None, max_epochs=100, batch_size=64, validation_iter=100, 316 | dtype='float32', early_stopping=None, verbose=True): 317 | """Fit the model to data and validate it periodically. 318 | 319 | This method controls the training of a BPNet model. It will fit the 320 | model to examples generated by the `training_data` DataLoader object 321 | and, if validation data is provided, will periodically validate the 322 | model against it and return those values. The periodicity can be 323 | controlled using the `validation_iter` parameter. 324 | 325 | Two versions of the model will be saved: the best model found during 326 | training according to the validation measures, and the final model 327 | at the end of training. Additionally, a log will be saved of the 328 | training and validation statistics, e.g. time and performance. 329 | 330 | 331 | Parameters 332 | ---------- 333 | training_data: torch.utils.data.DataLoader 334 | A generator that produces examples to train on. If n_control_tracks 335 | is greater than 0, must product two inputs, otherwise must produce 336 | only one input. 337 | 338 | optimizer: torch.optim.Optimizer 339 | An optimizer to control the training of the model. 340 | 341 | X_valid: torch.tensor or None, shape=(n, 4, 2114) 342 | A block of sequences to validate on periodically. If None, do not 343 | perform validation. Default is None. 344 | 345 | X_ctl_valid: torch.tensor or None, shape=(n, n_control_tracks, 2114) 346 | A block of control sequences to validate on periodically. If 347 | n_control_tracks is None, pass in None. Default is None. 348 | 349 | y_valid: torch.tensor or None, shape=(n, n_outputs, 1000) 350 | A block of signals to validate against. Must be provided if 351 | X_valid is also provided. Default is None. 352 | 353 | max_epochs: int 354 | The maximum number of epochs to train for, as measured by the 355 | number of times that `training_data` is exhausted. Default is 100. 356 | 357 | batch_size: int, optional 358 | The number of examples to include in each batch. Default is 64. 359 | 360 | dtype: str, optional 361 | Whether to use mixed precision and, if so, what dtype to use. If not 362 | using 'float32', recommended is to use 'bfloat16'. Default is 'float32'. 363 | 364 | validation_iter: int 365 | The number of batches to train on before validating against the 366 | entire validation set. When the validation set is large, this 367 | enables the total validating time to be small compared to the 368 | training time by only validating periodically. Default is 100. 369 | 370 | early_stopping: int or None, optional 371 | Whether to stop training early. If None, continue training until 372 | max_epochs is reached. If an integer, continue training until that 373 | number of `validation_iter` ticks has been hit without improvement 374 | in performance. Default is None. 375 | 376 | verbose: bool 377 | Whether to print out the training and evaluation statistics during 378 | training. Default is True. 379 | """ 380 | 381 | if X_valid is not None: 382 | y_valid_counts = y_valid.sum(dim=2) 383 | 384 | if X_ctl_valid is not None: 385 | X_ctl_valid = (X_ctl_valid,) 386 | 387 | dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype 388 | 389 | iteration = 0 390 | early_stop_count = 0 391 | best_loss = float("inf") 392 | self.logger.start() 393 | 394 | for epoch in range(max_epochs): 395 | tic = time.time() 396 | 397 | for data in training_data: 398 | if len(data) == 3: 399 | X, X_ctl, y = data 400 | X, X_ctl, y = X.cuda().float(), X_ctl.cuda(), y.cuda() 401 | else: 402 | X, y = data 403 | X, y = X.cuda().float(), y.cuda() 404 | X_ctl = None 405 | 406 | # Clear the optimizer and set the model to training mode 407 | optimizer.zero_grad() 408 | self.train() 409 | 410 | # Run forward pass 411 | with torch.autocast(device_type='cuda', dtype=dtype): 412 | y_profile, y_counts = self(X, X_ctl) 413 | 414 | y_profile = y_profile.reshape(y_profile.shape[0], -1) 415 | y_profile = torch.nn.functional.log_softmax(y_profile, dim=-1) 416 | 417 | y = y.reshape(y.shape[0], -1) 418 | y_ = y.sum(dim=-1).reshape(-1, 1) 419 | 420 | # Calculate the profile and count losses 421 | profile_loss = MNLLLoss(y_profile, y).mean() 422 | count_loss = log1pMSELoss(y_counts, y_).mean() 423 | 424 | # Extract the profile loss for logging 425 | profile_loss_ = profile_loss.item() 426 | count_loss_ = count_loss.item() 427 | 428 | # Mix losses together and update the model 429 | loss = profile_loss + self.alpha * count_loss 430 | loss.backward() 431 | optimizer.step() 432 | 433 | # Report measures if desired 434 | if verbose and iteration % validation_iter == 0: 435 | train_time = time.time() - tic 436 | 437 | with torch.no_grad(): 438 | self.eval() 439 | 440 | tic = time.time() 441 | y_profile, y_counts = predict(self, X_valid, 442 | args=X_ctl_valid, batch_size=batch_size, 443 | dtype=dtype, device='cuda') 444 | 445 | z = y_profile.shape 446 | y_profile = y_profile.reshape(y_profile.shape[0], -1) 447 | y_profile = torch.nn.functional.log_softmax(y_profile, 448 | dim=-1) 449 | y_profile = y_profile.reshape(*z) 450 | 451 | measures = calculate_performance_measures(y_profile, 452 | y_valid, y_counts, kernel_sigma=7, 453 | kernel_width=81, measures=['profile_mnll', 454 | 'profile_pearson', 'count_pearson', 'count_mse']) 455 | 456 | profile_corr = measures['profile_pearson'] 457 | count_corr = measures['count_pearson'] 458 | 459 | valid_loss = measures['profile_mnll'].mean() 460 | valid_loss += self.alpha * measures['count_mse'].mean() 461 | valid_time = time.time() - tic 462 | 463 | self.logger.add([epoch, iteration, train_time, 464 | valid_time, profile_loss_, count_loss_, 465 | measures['profile_mnll'].mean().item(), 466 | numpy.nan_to_num(profile_corr).mean(), 467 | numpy.nan_to_num(count_corr).mean(), 468 | measures['count_mse'].mean().item(), 469 | (valid_loss < best_loss).item()]) 470 | 471 | self.logger.save("{}.log".format(self.name)) 472 | 473 | if valid_loss < best_loss: 474 | torch.save(self, "{}.torch".format(self.name)) 475 | best_loss = valid_loss 476 | early_stop_count = 0 477 | else: 478 | early_stop_count += 1 479 | 480 | if early_stopping is not None and early_stop_count >= early_stopping: 481 | break 482 | 483 | iteration += 1 484 | 485 | if early_stopping is not None and early_stop_count >= early_stopping: 486 | break 487 | 488 | torch.save(self, "{}.final.torch".format(self.name)) 489 | 490 | 491 | @classmethod 492 | def from_chrombpnet_lite(cls, filename): 493 | """Loads a model from ChromBPNet-lite TensorFlow format. 494 | 495 | This method will load a ChromBPNet-lite model from TensorFlow format. 496 | Note that this is not the same as ChromBPNet format. Specifically, 497 | ChromBPNet-lite was a preceeding package that had a slightly different 498 | saving format, whereas ChromBPNet is the packaged version of that 499 | code that is applied at scale. 500 | 501 | This method does not load the entire ChromBPNet model. If that is 502 | the desired behavior, see the `ChromBPNet` object and its associated 503 | loading functions. Instead, this loads a single BPNet model -- either 504 | the bias model or the accessibility model, depending on what is encoded 505 | in the stored file. 506 | 507 | 508 | Parameters 509 | ---------- 510 | filename: str 511 | The name of the h5 file that stores the trained model parameters. 512 | 513 | 514 | Returns 515 | ------- 516 | model: BPNet 517 | A BPNet model compatible with this repository in PyTorch. 518 | """ 519 | 520 | h5 = h5py.File(filename, "r") 521 | w = h5['model_weights'] 522 | 523 | if 'model_1' in w.keys(): 524 | w = w['model_1'] 525 | bias = False 526 | else: 527 | bias = True 528 | 529 | k, b = 'kernel:0', 'bias:0' 530 | name = "conv1d_{}_1" if not bias else "conv1d_{0}/conv1d_{0}" 531 | 532 | layer_names = [] 533 | for layer_name in w.keys(): 534 | try: 535 | idx = int(layer_name.split("_")[1]) 536 | layer_names.append(idx) 537 | except: 538 | pass 539 | 540 | n_filters = w[name.format(1)][k].shape[2] 541 | n_layers = max(layer_names) - 2 542 | 543 | model = BPNet(n_layers=n_layers, n_filters=n_filters, n_outputs=1, 544 | n_control_tracks=0, trimming=(2114-1000)//2) 545 | 546 | convert_w = lambda x: torch.nn.Parameter(torch.tensor( 547 | x[:]).permute(2, 1, 0)) 548 | convert_b = lambda x: torch.nn.Parameter(torch.tensor(x[:])) 549 | 550 | model.iconv.weight = convert_w(w[name.format(1)][k]) 551 | model.iconv.bias = convert_b(w[name.format(1)][b]) 552 | model.iconv.padding = 12 553 | 554 | for i in range(2, n_layers+2): 555 | model.rconvs[i-2].weight = convert_w(w[name.format(i)][k]) 556 | model.rconvs[i-2].bias = convert_b(w[name.format(i)][b]) 557 | 558 | model.fconv.weight = convert_w(w[name.format(n_layers+2)][k]) 559 | model.fconv.bias = convert_b(w[name.format(n_layers+2)][b]) 560 | model.fconv.padding = 12 561 | 562 | name = "logcounts_1" if not bias else "logcounts/logcounts" 563 | model.linear.weight = torch.nn.Parameter(torch.tensor(w[name][k][:].T)) 564 | model.linear.bias = convert_b(w[name][b]) 565 | return model 566 | 567 | 568 | @classmethod 569 | def from_chrombpnet(cls, filename): 570 | """Loads a model from ChromBPNet TensorFlow format. 571 | 572 | This method will load one of the components of a ChromBPNet model 573 | from TensorFlow format. Note that a full ChromBPNet model is made up 574 | of an accessibility model and a bias model and that this will load 575 | one of the two. Use `ChromBPNet.from_chrombpnet` to end up with the 576 | entire ChromBPNet model. 577 | 578 | 579 | Parameters 580 | ---------- 581 | filename: str 582 | The name of the h5 file that stores the trained model parameters. 583 | 584 | 585 | Returns 586 | ------- 587 | model: BPNet 588 | A BPNet model compatible with this repository in PyTorch. 589 | """ 590 | 591 | h5 = h5py.File(filename, "r") 592 | w = h5['model_weights'] 593 | 594 | if 'bpnet_1conv' in w.keys(): 595 | prefix = "" 596 | else: 597 | prefix = "wo_bias_" 598 | 599 | namer = lambda prefix, suffix: '{0}{1}/{0}{1}'.format(prefix, suffix) 600 | k, b = 'kernel:0', 'bias:0' 601 | 602 | n_layers = 0 603 | for layer_name in w.keys(): 604 | try: 605 | idx = int(layer_name.split("_")[-1].replace("conv", "")) 606 | n_layers = max(n_layers, idx) 607 | except: 608 | pass 609 | 610 | name = namer(prefix, "bpnet_1conv") 611 | n_filters = w[name][k].shape[2] 612 | 613 | model = BPNet(n_layers=n_layers, n_filters=n_filters, n_outputs=1, 614 | n_control_tracks=0, trimming=(2114-1000)//2) 615 | 616 | convert_w = lambda x: torch.nn.Parameter(torch.tensor( 617 | x[:]).permute(2, 1, 0)) 618 | convert_b = lambda x: torch.nn.Parameter(torch.tensor(x[:])) 619 | 620 | iname = namer(prefix, 'bpnet_1st_conv') 621 | 622 | model.iconv.weight = convert_w(w[iname][k]) 623 | model.iconv.bias = convert_b(w[iname][b]) 624 | model.iconv.padding = ((21 - 1) // 2,) 625 | 626 | for i in range(1, n_layers+1): 627 | lname = namer(prefix, 'bpnet_{}conv'.format(i)) 628 | 629 | model.rconvs[i-1].weight = convert_w(w[lname][k]) 630 | model.rconvs[i-1].bias = convert_b(w[lname][b]) 631 | 632 | prefix = prefix + "bpnet_" if prefix != "" else "" 633 | 634 | fname = namer(prefix, 'prof_out_precrop') 635 | model.fconv.weight = convert_w(w[fname][k]) 636 | model.fconv.bias = convert_b(w[fname][b]) 637 | model.fconv.padding = ((75 - 1) // 2,) 638 | 639 | name = namer(prefix, "logcount_predictions") 640 | model.linear.weight = torch.nn.Parameter(torch.tensor(w[name][k][:].T)) 641 | model.linear.bias = convert_b(w[name][b]) 642 | return model 643 | 644 | 645 | class BasePairNet(torch.nn.Module): 646 | """A BPNet implementation matching that in basepairmodels 647 | 648 | This is a BPNet implementation that matches the one in basepairmodels and 649 | can be used to load models trained from that repository, e.g., those trained 650 | as part of the atlas project. The architecture of the model is identical to 651 | `BPNet` except that output from the residual layers is added to the 652 | pre-activation outputs from the previous layer, rather than to the 653 | post-activation outputs from the previous layer. Additionally, the count 654 | prediction head takes the sum of the control track counts, adds two instead 655 | of one, and then takes the log. Neither detail dramatically changes 656 | performance of the model but is necessary to account for when loading 657 | trained models. 658 | 659 | 660 | Parameters 661 | ---------- 662 | n_filters: int, optional 663 | The number of filters to use per convolution. Default is 64. 664 | 665 | n_layers: int, optional 666 | The number of dilated residual layers to include in the model. 667 | Default is 8. 668 | 669 | n_outputs: int, optional 670 | The number of profile outputs from the model. Generally either 1 or 2 671 | depending on if the data is unstranded or stranded. Default is 2. 672 | 673 | n_control_tracks: int, optional 674 | The number of control tracks to feed into the model. When predicting 675 | TFs, this is usually 2. When predicting accessibility, this is usualy 676 | 0. When 0, this input is removed from the model. Default is 2. 677 | 678 | alpha: float, optional 679 | The weight to put on the count loss. 680 | 681 | profile_output_bias: bool, optional 682 | Whether to include a bias term in the final profile convolution. 683 | Removing this term can help with attribution stability and will usually 684 | not affect performance. Default is True. 685 | 686 | count_output_bias: bool, optional 687 | Whether to include a bias term in the linear layer used to predict 688 | counts. Removing this term can help with attribution stability but 689 | may affect performance. Default is True. 690 | 691 | name: str or None, optional 692 | The name to save the model to during training. 693 | 694 | trimming: int or None, optional 695 | The amount to trim from both sides of the input window to get the 696 | output window. This value is removed from both sides, so the total 697 | number of positions removed is 2*trimming. 698 | 699 | verbose: bool, optional 700 | Whether to display statistics during training. Setting this to False 701 | will still save the file at the end, but does not print anything to 702 | screen during training. Default is True. 703 | """ 704 | 705 | def __init__(self, n_filters=64, n_layers=8, n_outputs=2, 706 | n_control_tracks=2, alpha=1, profile_output_bias=True, 707 | count_output_bias=True, name=None, trimming=None, verbose=True): 708 | super(BasePairNet, self).__init__() 709 | self.n_filters = n_filters 710 | self.n_layers = n_layers 711 | self.n_outputs = n_outputs 712 | self.n_control_tracks = n_control_tracks 713 | 714 | self.alpha = alpha 715 | self.name = name or "bpnet.{}.{}".format(n_filters, n_layers) 716 | self.trimming = trimming or 2 ** n_layers 717 | 718 | self.iconv = torch.nn.Conv1d(4, n_filters, kernel_size=21, padding=10) 719 | self.irelu = torch.nn.ReLU() 720 | 721 | self.rconvs = torch.nn.ModuleList([ 722 | torch.nn.Conv1d(n_filters, n_filters, kernel_size=3, padding=2**i, 723 | dilation=2**i) for i in range(1, self.n_layers+1) 724 | ]) 725 | self.rrelus = torch.nn.ModuleList([ 726 | torch.nn.ReLU() for i in range(1, self.n_layers+1) 727 | ]) 728 | 729 | self.fconv = torch.nn.Conv1d(n_filters+n_control_tracks, n_outputs, 730 | kernel_size=75, padding=37, bias=profile_output_bias) 731 | 732 | n_count_control = 1 if n_control_tracks > 0 else 0 733 | self.linear = torch.nn.Linear(n_filters+n_count_control, 1, 734 | bias=count_output_bias) 735 | 736 | self.logger = Logger(["Epoch", "Iteration", "Training Time", 737 | "Validation Time", "Training MNLL", "Training Count MSE", 738 | "Validation MNLL", "Validation Profile Pearson", 739 | "Validation Count Pearson", "Validation Count MSE", "Saved?"], 740 | verbose=verbose) 741 | 742 | 743 | def forward(self, X, X_ctl=None): 744 | """A forward pass of the model. 745 | 746 | This method takes in a nucleotide sequence X, a corresponding 747 | per-position value from a control track, and a per-locus value 748 | from the control track and makes predictions for the profile 749 | and for the counts. This per-locus value is usually the 750 | log(sum(X_ctl_profile)+1) when the control is an experimental 751 | read track but can also be the output from another model. 752 | 753 | Parameters 754 | ---------- 755 | X: torch.tensor, shape=(batch_size, 4, length) 756 | The one-hot encoded batch of sequences. 757 | 758 | X_ctl: torch.tensor or None, shape=(batch_size, n_strands, length) 759 | A value representing the signal of the control at each position in 760 | the sequence. If no controls, pass in None. Default is None. 761 | 762 | Returns 763 | ------- 764 | y_profile: torch.tensor, shape=(batch_size, n_strands, out_length) 765 | The output predictions for each strand trimmed to the output 766 | length. 767 | """ 768 | 769 | start, end = self.trimming, X.shape[2] - self.trimming 770 | 771 | X = self.iconv(X) 772 | for i in range(self.n_layers): 773 | X_a = self.rrelus[i](X) 774 | X_conv = self.rconvs[i](X_a) 775 | X = torch.add(X, X_conv) 776 | X = self.irelu(X) 777 | 778 | if X_ctl is None: 779 | X_w_ctl = X 780 | else: 781 | X_w_ctl = torch.cat([X, X_ctl], dim=1) 782 | 783 | y_profile = self.fconv(X_w_ctl)[:, :, start:end] 784 | 785 | # counts prediction 786 | X = torch.mean(X[:, :, start-37:end+37], dim=2) 787 | if X_ctl is not None: 788 | X_ctl = torch.sum(X_ctl[:, :, start:end], dim=(1, 2)) 789 | X_ctl = X_ctl.unsqueeze(-1) 790 | X = torch.cat([X, torch.log(X_ctl+2)], dim=-1) 791 | 792 | y_counts = self.linear(X).reshape(X.shape[0], 1) 793 | return y_profile, y_counts 794 | 795 | 796 | @classmethod 797 | def from_bpnet(cls, filename): 798 | """Loads a model from BPNet TensorFlow format. 799 | 800 | This method will allow you to load a BPNet model from the basepairmodels 801 | repo that has been saved in TensorFlow format. You do not need to have 802 | TensorFlow installed to use this function. The result will be a model 803 | whose predictions and attributions are identical to those produced when 804 | using the TensorFlow code. 805 | 806 | 807 | Parameters 808 | ---------- 809 | filename: str 810 | The name of the h5 file that stores the trained model parameters. 811 | 812 | 813 | Returns 814 | ------- 815 | model: BPNet 816 | A BPNet model compatible with this repository in PyTorch. 817 | """ 818 | 819 | h5 = h5py.File(filename, "r") 820 | w, k, b = h5['model_weights'], 'kernel:0', 'bias:0' 821 | 822 | extract = lambda name, suffix: w['{0}/{0}/{1}'.format(name, suffix)][:] 823 | convert_w = lambda x: torch.nn.Parameter(torch.tensor(x).permute(2, 1, 824 | 0)) 825 | convert_b = lambda x: torch.nn.Parameter(torch.tensor(x)) 826 | 827 | n_layers, n_filters = 0, extract("main_conv_0", k).shape[2] 828 | for layer_name in w.keys(): 829 | if 'main_dil_conv' in layer_name: 830 | n_layers = max(n_layers, int(layer_name.split("_")[-1])) 831 | 832 | model = cls(n_layers=n_layers, n_filters=n_filters, n_outputs=2, 833 | n_control_tracks=2, trimming=(2114-1000)//2) 834 | 835 | model.iconv.weight = convert_w(extract("main_conv_0", k)) 836 | model.iconv.bias = convert_b(extract("main_conv_0", b)) 837 | model.iconv.padding = ((model.iconv.weight.shape[-1] - 1) // 2,) 838 | 839 | for i in range(1, n_layers+1): 840 | lname = "main_dil_conv_{}".format(i) 841 | model.rconvs[i-1].weight = convert_w(extract(lname, k)) 842 | model.rconvs[i-1].bias = convert_b(extract(lname, b)) 843 | 844 | w0 = model.fconv.weight.numpy(force=True) 845 | wph = extract("main_profile_head", k) 846 | wpp = extract("profile_predictions", k)[0, :2] 847 | 848 | conv_weight = numpy.zeros_like(w0.transpose(2, 1, 0)) 849 | conv_weight[:, :n_filters] = wph.dot(wpp) 850 | conv_weight[37, n_filters:] = extract("profile_predictions", k)[0, 2:] 851 | model.fconv.weight = convert_w(conv_weight) 852 | model.fconv.bias = (convert_b(extract("main_profile_head", b) + 853 | extract("profile_predictions", b))) 854 | model.fconv.padding = ((model.fconv.weight.shape[-1] - 1) // 2,) 855 | 856 | linear_weight = numpy.zeros_like(model.linear.weight.numpy(force=True)) 857 | linear_weight[:, :n_filters] = (extract("main_counts_head", k).T * 858 | extract("logcounts_predictions", k)[0]) 859 | linear_weight[:, -1] = extract("logcounts_predictions", k)[1] 860 | 861 | model.linear.weight = convert_b(linear_weight) 862 | model.linear.bias = (convert_b(extract("main_counts_head", b) * 863 | extract("logcounts_predictions", k)[0] + 864 | extract("logcounts_predictions", b))) 865 | return model 866 | -------------------------------------------------------------------------------- /bpnetlite/chrombpnet.py: -------------------------------------------------------------------------------- 1 | # chrombpnet.py 2 | # Author: Jacob Schreiber 3 | 4 | import h5py 5 | import time 6 | import numpy 7 | import torch 8 | 9 | from .bpnet import BPNet 10 | from .losses import MNLLLoss 11 | from .losses import log1pMSELoss 12 | from .performance import calculate_performance_measures 13 | from .logging import Logger 14 | 15 | from tqdm import trange 16 | 17 | from tangermeme.predict import predict 18 | 19 | 20 | class _Exp(torch.nn.Module): 21 | def __init__(self): 22 | super(_Exp, self).__init__() 23 | 24 | def forward(self, X): 25 | return torch.exp(X) 26 | 27 | 28 | class _Log(torch.nn.Module): 29 | def __init__(self): 30 | super(_Log, self).__init__() 31 | 32 | def forward(self, X): 33 | return torch.log(X) 34 | 35 | 36 | class ChromBPNet(torch.nn.Module): 37 | """A ChromBPNet model. 38 | 39 | ChromBPNet is an extension of BPNet to handle chromatin accessibility data, 40 | in contrast to the protein binding data that BPNet handles. The distinction 41 | between these data types is that an enzyme used in DNase-seq and ATAC-seq 42 | experiments itself has a soft sequence preference, meaning that the 43 | strength of the signal is driven by real biology but that the exact read 44 | mapping locations are driven by the soft sequence bias of the enzyme. 45 | 46 | ChromBPNet handles this by treating the data using two models: a bias 47 | model that is initially trained on background (non-peak) regions where 48 | the bias dominates, and an accessibility model that is subsequently trained 49 | using a frozen version of the bias model. The bias model learns to remove 50 | the enzyme bias so that the accessibility model can learn real motifs. 51 | 52 | 53 | Parameters 54 | ---------- 55 | bias: torch.nn.Module 56 | This model takes in sequence and outputs the shape one would expect in 57 | ATAC-seq data due to Tn5 bias alone. This is usually a BPNet model 58 | from the bpnet-lite repo that has been trained on GC-matched non-peak 59 | regions. 60 | 61 | accessibility: torch.nn.Module 62 | This model takes in sequence and outputs the accessibility one would 63 | expect due to the components of the sequence, but also takes in a cell 64 | representation which modifies the parameters of the model, hence, 65 | "dynamic." This model is usually a DynamicBPNet model, defined below. 66 | 67 | name: str 68 | The name to prepend when saving the file. 69 | """ 70 | 71 | def __init__(self, bias, accessibility, name): 72 | super(ChromBPNet, self).__init__() 73 | for parameter in bias.parameters(): 74 | parameter.requires_grad = False 75 | 76 | self.bias = bias 77 | self.accessibility = accessibility 78 | self.name = name 79 | self.logger = None 80 | self.n_control_tracks = accessibility.n_control_tracks 81 | self.n_outputs = 1 82 | self._log = _Log() 83 | self._exp1 = _Exp() 84 | self._exp2 = _Exp() 85 | 86 | 87 | def forward(self, X, X_ctl=None): 88 | """A forward pass through the network. 89 | 90 | This function is usually accessed through calling the model, e.g. 91 | doing `model(x)`. The method defines how inputs are transformed into 92 | the outputs through interactions with each of the layers. 93 | 94 | 95 | Parameters 96 | ---------- 97 | X: torch.tensor, shape=(-1, 4, 2114) 98 | A one-hot encoded sequence tensor. 99 | 100 | X_ctl: ignore 101 | An ignored parameter for consistency with attribution functions. 102 | 103 | 104 | Returns 105 | ------- 106 | y_profile: torch.tensor, shape=(-1, 1000) 107 | The predicted logit profile for each example. Note that this is not 108 | a normalized value. 109 | """ 110 | 111 | acc_profile, acc_counts = self.accessibility(X) 112 | bias_profile, bias_counts = self.bias(X) 113 | 114 | y_profile = acc_profile + bias_profile 115 | y_counts = self._log(self._exp1(acc_counts) + self._exp2(bias_counts)) 116 | 117 | return y_profile, y_counts 118 | 119 | 120 | def fit(self, training_data, optimizer, X_valid=None, y_valid=None, 121 | max_epochs=100, batch_size=64, validation_iter=100, early_stopping=None, 122 | verbose=True): 123 | """Fit the ChromBPNet model to data. 124 | 125 | Specifically, this function will fit the accessibility model to 126 | observed chromatin accessibility data, and assume that the bias model 127 | is frozen and pre-trained. Hence, the only parameters being trained 128 | in this function are those in the accessibility model. 129 | 130 | This function will save the best full ChromBPNet model, as well as the 131 | best accessibility model, found during training. 132 | 133 | 134 | Parameter 135 | --------- 136 | training_data: torch.utils.data.DataLoader 137 | A data set that generates one-hot encoded sequence as input and 138 | read count signal for the output. 139 | 140 | optimizer: torch.optim.Optimizer 141 | A PyTorch optimizer. 142 | 143 | X_valid: torch.Tensor or None, shape=(-1, 4, length) 144 | A tensor of one-hot encoded sequences to use as input for the 145 | validation steps. If None, do not do validation. Default is None. 146 | 147 | y_valid: torch.Tensor or None, shape=(-1, 1, length) 148 | A tensor of read counts matched with the `X_valid` input. If None, 149 | do not do validation. Default is None. 150 | 151 | max_epochs: int 152 | The maximum number of training epochs to perform before stopping. 153 | Default is 100. 154 | 155 | batch_size: int 156 | The number of examples to use in each batch. Default is 64. 157 | 158 | dtype: str, optional 159 | Whether to use mixed precision and, if so, what dtype to use. If not 160 | using 'float32', recommended is to use 'bfloat16'. Default is 'float32'. 161 | 162 | validation_iter: int 163 | The number of training batches to perform before doing another 164 | round of validation. Set higher to spend a higher percentage of 165 | time in the training step. 166 | 167 | early_stopping: int or None 168 | Whether to stop training early. If None, continue training until 169 | max_epochs is reached. If an integer, continue training until that 170 | number of `validation_iter` ticks has been hit without improvement 171 | in performance. Default is None. 172 | 173 | verbose: bool 174 | Whether to print the log as it is being generated. A log will 175 | be returned at the end of training regardless of this option, but 176 | when False, nothing will be printed to the screen during training. 177 | Default is False 178 | """ 179 | 180 | dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype 181 | 182 | y_bias_profile, y_bias_counts = predict(self.bias, X_valid, 183 | batch_size=batch_size, dtype=dtype, device='cuda') 184 | 185 | self.logger = Logger(["Epoch", "Iteration", "Training Time", 186 | "Validation Time", "Training MNLL", "Training Count MSE", 187 | "Validation MNLL", "Validation Profile Correlation", 188 | "Validation Count Pearson", "Validation Count MSE", "Saved?"], 189 | verbose=verbose) 190 | 191 | early_stop_count = 0 192 | start, best_loss = time.time(), float("inf") 193 | 194 | self.logger.start() 195 | self.bias.eval() 196 | for epoch in range(max_epochs): 197 | for iteration, (X, y) in enumerate(training_data): 198 | self.accessibility.train() 199 | 200 | X = X.cuda() 201 | y = y.cuda() 202 | 203 | optimizer.zero_grad() 204 | 205 | acc_profile, acc_counts = self.accessibility(X) 206 | bias_profile, bias_counts = self.bias(X) 207 | 208 | y_profile = torch.nn.functional.log_softmax(acc_profile + 209 | bias_profile, dim=-1) 210 | 211 | y_counts = torch.logsumexp(torch.stack([acc_counts, 212 | bias_counts]), dim=0) 213 | 214 | profile_loss = MNLLLoss(y_profile, y).mean() 215 | count_loss = log1pMSELoss(y_counts, y.sum(dim=-1).reshape(-1, 216 | 1)).mean() 217 | 218 | profile_loss_ = profile_loss.item() 219 | count_loss_ = count_loss.item() 220 | 221 | loss = profile_loss + self.accessibility.alpha * count_loss 222 | loss.backward() 223 | optimizer.step() 224 | 225 | if verbose and iteration % validation_iter == 0: 226 | train_time = time.time() - start 227 | tic = time.time() 228 | 229 | with torch.no_grad(): 230 | self.accessibility.eval() 231 | 232 | y_profile, y_counts = predict(self.accessibility, 233 | X_valid, batch_size=batch_size, device='cuda') 234 | 235 | y_profile = torch.nn.functional.log_softmax( 236 | y_profile + y_bias_profile, dim=-1) 237 | 238 | y_counts = torch.logsumexp(torch.stack([y_counts, 239 | y_bias_counts]), dim=0) 240 | 241 | measures = calculate_performance_measures(y_profile, 242 | y_valid, y_counts, kernel_sigma=7, 243 | kernel_width=81, measures=['profile_mnll', 244 | 'profile_pearson', 'count_pearson', 'count_mse']) 245 | 246 | profile_corr = measures['profile_pearson'] 247 | count_corr = measures['count_pearson'] 248 | 249 | valid_loss = measures['profile_mnll'].mean() 250 | valid_loss += self.accessibility.alpha * measures['count_mse'].mean() 251 | valid_time = time.time() - tic 252 | 253 | self.logger.add([epoch, iteration, train_time, 254 | valid_time, profile_loss_, count_loss_, 255 | measures['profile_mnll'].mean().item(), 256 | numpy.nan_to_num(profile_corr).mean(), 257 | numpy.nan_to_num(count_corr).mean(), 258 | measures['count_mse'].mean().item(), 259 | (valid_loss < best_loss).item()]) 260 | 261 | if valid_loss < best_loss: 262 | torch.save(self, "{}.torch".format(self.name)) 263 | torch.save(self.accessibility, 264 | "{}.accessibility.torch".format(self.name)) 265 | best_loss = valid_loss 266 | early_stop_count = 0 267 | else: 268 | early_stop_count += 1 269 | 270 | start = time.time() 271 | 272 | if early_stopping is not None and early_stop_count >= early_stopping: 273 | break 274 | 275 | self.logger.save("{}.log".format(self.name)) 276 | 277 | if early_stopping is not None and early_stop_count >= early_stopping: 278 | break 279 | 280 | torch.save(self, "{}.final.torch".format(self.name)) 281 | torch.save(self, "{}.accessibility.final.torch".format(self.name)) 282 | 283 | 284 | @classmethod 285 | def from_chrombpnet_lite(self, bias_model, accessibility_model, name): 286 | """Load a ChromBPNet model trained in ChromBPNet-lite. 287 | 288 | Confusingly, ChromBPNet-lite is a package written by Surag Nair that 289 | reorganized the ChromBPNet library and then was reintegrated back 290 | into it. However, some ChromBPNet models are still around that were 291 | trained using this package and this is a method for loading those 292 | models, not the models trained using the ChromBPNet package and not 293 | ChromBPNet models trained using this package. 294 | 295 | This method takes in paths to a h5 file containing the weights of the 296 | bias model and the accessibility model, both trained and whose outputs 297 | are organized according to TensorFlow. The weights are loaded and 298 | shaped into a PyTorch model and can be used as such. 299 | 300 | 301 | Parameters 302 | ---------- 303 | bias model: str 304 | The filename of the bias model. 305 | 306 | accessibility_model: str 307 | The filename of the accessibility model. 308 | 309 | name: str 310 | The name to use when training the model and outputting to a file. 311 | 312 | 313 | Returns 314 | ------- 315 | model: bpnetlite.models.ChromBPNet 316 | A PyTorch ChromBPNet model compatible with the bpnet-lite package. 317 | """ 318 | 319 | bias = BPNet.from_chrombpnet_lite(bias_model) 320 | acc = BPNet.from_chrombpnet_lite(accessibility_model) 321 | return ChromBPNet(bias, acc, name) 322 | 323 | 324 | @classmethod 325 | def from_chrombpnet(self, bias_model, accessibility_model, name): 326 | """Load a ChromBPNet model trained using the official repository. 327 | 328 | This method takes in the path to a .h5 file containing the full model, 329 | i.e., the bias model AND the accessibility model. If you have two 330 | files -- one for the bias model, and one for the accessibility model -- 331 | load them up as separate BPNet models and create a ChromBPNet object 332 | afterwards. 333 | 334 | 335 | Parameters 336 | ---------- 337 | bias model: str 338 | The filename of the bias model. 339 | 340 | accessibility_model: str 341 | The filename of the accessibility model. 342 | 343 | name: str 344 | The name to use when training the model and outputting to a file. 345 | 346 | 347 | Returns 348 | ------- 349 | model: bpnetlite.models.ChromBPNet 350 | A PyTorch ChromBPNet model compatible with the bpnet-lite package. 351 | """ 352 | 353 | 354 | bias = BPNet.from_chrombpnet(bias_model) 355 | acc = BPNet.from_chrombpnet(accessibility_model) 356 | return ChromBPNet(bias, acc, name) -------------------------------------------------------------------------------- /bpnetlite/io.py: -------------------------------------------------------------------------------- 1 | # io.py 2 | # Author: Jacob Schreiber 3 | # Code adapted from Alex Tseng, Avanti Shrikumar, and Ziga Avsec 4 | 5 | import numpy 6 | import torch 7 | 8 | from tqdm import tqdm 9 | from tangermeme.io import extract_loci 10 | 11 | 12 | class DataGenerator(torch.utils.data.Dataset): 13 | """A data generator for BPNet inputs. 14 | 15 | This generator takes in an extracted set of sequences, output signals, 16 | and control signals, and will return a single element with random 17 | jitter and reverse-complement augmentation applied. Jitter is implemented 18 | efficiently by taking in data that is wider than the in/out windows by 19 | two times the maximum jitter and windows are extracted from that. 20 | Essentially, if an input window is 1000 and the maximum jitter is 128, one 21 | would pass in data with a length of 1256 and a length 1000 window would be 22 | extracted starting between position 0 and 256. This generator must be 23 | wrapped by a PyTorch generator object. 24 | 25 | Parameters 26 | ---------- 27 | sequences: torch.tensor, shape=(n, 4, in_window+2*max_jitter) 28 | A one-hot encoded tensor of `n` example sequences, each of input 29 | length `in_window`. See description above for connection with jitter. 30 | 31 | signals: torch.tensor, shape=(n, t, out_window+2*max_jitter) 32 | The signals to predict, usually counts, for `n` examples with 33 | `t` output tasks (usually 2 if stranded, 1 otherwise), each of 34 | output length `out_window`. See description above for connection 35 | with jitter. 36 | 37 | controls: torch.tensor, shape=(n, t, out_window+2*max_jitter) or None, optional 38 | The control signal to take as input, usually counts, for `n` 39 | examples with `t` strands and output length `out_window`. If 40 | None, does not return controls. 41 | 42 | in_window: int, optional 43 | The input window size. Default is 2114. 44 | 45 | out_window: int, optional 46 | The output window size. Default is 1000. 47 | 48 | max_jitter: int, optional 49 | The maximum amount of jitter to add, in either direction, to the 50 | midpoints that are passed in. Default is 0. 51 | 52 | reverse_complement: bool, optional 53 | Whether to reverse complement-augment half of the data. Default is False. 54 | 55 | random_state: int or None, optional 56 | Whether to use a deterministic seed or not. 57 | """ 58 | 59 | def __init__(self, sequences, signals, controls=None, in_window=2114, 60 | out_window=1000, max_jitter=0, reverse_complement=False, 61 | random_state=None): 62 | self.in_window = in_window 63 | self.out_window = out_window 64 | self.max_jitter = max_jitter 65 | 66 | self.reverse_complement = reverse_complement 67 | self.random_state = numpy.random.RandomState(random_state) 68 | 69 | self.signals = signals 70 | self.controls = controls 71 | self.sequences = sequences 72 | 73 | def __len__(self): 74 | return len(self.sequences) 75 | 76 | def __getitem__(self, idx): 77 | i = self.random_state.choice(len(self.sequences)) 78 | j = 0 if self.max_jitter == 0 else self.random_state.randint( 79 | self.max_jitter*2) 80 | 81 | X = self.sequences[i][:, j:j+self.in_window] 82 | y = self.signals[i][:, j:j+self.out_window] 83 | 84 | if self.controls is not None: 85 | X_ctl = self.controls[i][:, j:j+self.in_window] 86 | 87 | if self.reverse_complement and self.random_state.choice(2) == 1: 88 | X = torch.flip(X, [0, 1]) 89 | y = torch.flip(y, [0, 1]) 90 | 91 | if self.controls is not None: 92 | X_ctl = torch.flip(X_ctl, [0, 1]) 93 | 94 | if self.controls is not None: 95 | return X, X_ctl, y 96 | 97 | return X, y 98 | 99 | 100 | def PeakGenerator(loci, sequences, signals, controls=None, chroms=None, 101 | in_window=2114, out_window=1000, max_jitter=128, reverse_complement=True, 102 | min_counts=None, max_counts=None, random_state=None, pin_memory=True, 103 | num_workers=0, batch_size=32, verbose=False): 104 | """This is a constructor function that handles all IO. 105 | 106 | This function will extract signal from all signal and control files, 107 | pass that into a DataGenerator, and wrap that using a PyTorch data 108 | loader. This is the only function that needs to be used. 109 | 110 | Parameters 111 | ---------- 112 | loci: str or pandas.DataFrame or list/tuple of such 113 | Either the path to a bed file or a pandas DataFrame object containing 114 | three columns: the chromosome, the start, and the end, of each locus 115 | to train on. Alternatively, a list or tuple of strings/DataFrames where 116 | the intention is to train on the interleaved concatenation, i.e., when 117 | you want ot train on peaks and negatives. 118 | 119 | sequences: str or dictionary 120 | Either the path to a fasta file to read from or a dictionary where the 121 | keys are the unique set of chromosoms and the values are one-hot 122 | encoded sequences as numpy arrays or memory maps. 123 | 124 | signals: list of strs or list of dictionaries 125 | A list of filepaths to bigwig files, where each filepath will be read 126 | using pyBigWig, or a list of dictionaries where the keys are the same 127 | set of unique chromosomes and the values are numpy arrays or memory 128 | maps. 129 | 130 | controls: list of strs or list of dictionaries or None, optional 131 | A list of filepaths to bigwig files, where each filepath will be read 132 | using pyBigWig, or a list of dictionaries where the keys are the same 133 | set of unique chromosomes and the values are numpy arrays or memory 134 | maps. If None, no control tensor is returned. Default is None. 135 | 136 | chroms: list or None, optional 137 | A set of chromosomes to extact loci from. Loci in other chromosomes 138 | in the locus file are ignored. If None, all loci are used. Default is 139 | None. 140 | 141 | in_window: int, optional 142 | The input window size. Default is 2114. 143 | 144 | out_window: int, optional 145 | The output window size. Default is 1000. 146 | 147 | max_jitter: int, optional 148 | The maximum amount of jitter to add, in either direction, to the 149 | midpoints that are passed in. Default is 128. 150 | 151 | reverse_complement: bool, optional 152 | Whether to reverse complement-augment half of the data. Default is True. 153 | 154 | min_counts: float or None, optional 155 | The minimum number of counts, summed across the length of each example 156 | and across all tasks, needed to be kept. If None, no minimum. Default 157 | is None. 158 | 159 | max_counts: float or None, optional 160 | The maximum number of counts, summed across the length of each example 161 | and across all tasks, needed to be kept. If None, no maximum. Default 162 | is None. 163 | 164 | random_state: int or None, optional 165 | Whether to use a deterministic seed or not. 166 | 167 | pin_memory: bool, optional 168 | Whether to pin page memory to make data loading onto a GPU easier. 169 | Default is True. 170 | 171 | num_workers: int, optional 172 | The number of processes fetching data at a time to feed into a model. 173 | If 0, data is fetched from the main process. Default is 0. 174 | 175 | batch_size: int, optional 176 | The number of data elements per batch. Default is 32. 177 | 178 | verbose: bool, optional 179 | Whether to display a progress bar while loading. Default is False. 180 | 181 | Returns 182 | ------- 183 | X: torch.utils.data.DataLoader 184 | A PyTorch DataLoader wrapped DataGenerator object. 185 | """ 186 | 187 | X = extract_loci(loci=loci, sequences=sequences, signals=signals, 188 | in_signals=controls, chroms=chroms, in_window=in_window, 189 | out_window=out_window, max_jitter=max_jitter, min_counts=min_counts, 190 | max_counts=max_counts, ignore=list('QWERYUIOPSDFHJKLZXVBNM'), 191 | verbose=verbose) 192 | 193 | if controls is not None: 194 | sequences, signals_, controls_ = X 195 | else: 196 | sequences, signals_ = X 197 | controls_ = None 198 | 199 | #sequences = sequences.float() 200 | 201 | X_gen = DataGenerator(sequences, signals_, controls=controls_, 202 | in_window=in_window, out_window=out_window, max_jitter=max_jitter, 203 | reverse_complement=reverse_complement, random_state=random_state) 204 | 205 | X_gen = torch.utils.data.DataLoader(X_gen, pin_memory=pin_memory, 206 | num_workers=num_workers, batch_size=batch_size) 207 | 208 | return X_gen 209 | -------------------------------------------------------------------------------- /bpnetlite/logging.py: -------------------------------------------------------------------------------- 1 | # logging.py 2 | # Author: Jacob Schreiber 3 | 4 | import pandas 5 | 6 | class Logger(): 7 | """A logging class that can report or save metrics. 8 | 9 | This class contains a simple utility for saving statistics as they are 10 | generated, saving a report to a text file at the end, and optionally 11 | print the report to screen one line at a time as it is being generated. 12 | Must begin using the `start` method, which will reset the logger. 13 | 14 | 15 | Parameters 16 | ---------- 17 | names: list or tuple 18 | An iterable containing the names of the columns to be logged. 19 | 20 | verbose: bool, optional 21 | Whether to print to screen during the logging. 22 | """ 23 | 24 | def __init__(self, names, verbose=False): 25 | self.names = names 26 | self.verbose = verbose 27 | 28 | def start(self): 29 | """Begin the recording process.""" 30 | 31 | self.data = {name: [] for name in self.names} 32 | 33 | if self.verbose: 34 | print("\t".join(self.names)) 35 | 36 | def add(self, row): 37 | """Add a row to the log. 38 | 39 | This method will add one row to the log and, if verbosity is set, 40 | will print out the row to the log. The row must be the same length 41 | as the names given at instantiation. 42 | 43 | 44 | Parameters 45 | ---------- 46 | args: tuple or list 47 | An iterable containing the statistics to be saved. 48 | """ 49 | 50 | assert len(row) == len(self.names) 51 | 52 | for name, value in zip(self.names, row): 53 | self.data[name].append(value) 54 | 55 | if self.verbose: 56 | print("\t".join(map(str, [round(x, 4) if isinstance(x, float) else x 57 | for x in row]))) 58 | 59 | def save(self, name): 60 | """Write a log to disk. 61 | 62 | 63 | Parameters 64 | ---------- 65 | name: str 66 | The filename to save the logs to. 67 | """ 68 | 69 | pandas.DataFrame(self.data).to_csv(name, sep='\t', index=False) -------------------------------------------------------------------------------- /bpnetlite/losses.py: -------------------------------------------------------------------------------- 1 | # losses.py 2 | # Authors: Jacob Schreiber 3 | 4 | """ 5 | This module contains the losses used by BPNet for training. 6 | """ 7 | 8 | import torch 9 | 10 | def MNLLLoss(logps, true_counts): 11 | """A loss function based on the multinomial negative log-likelihood. 12 | 13 | This loss function takes in a tensor of normalized log probabilities such 14 | that the sum of each row is equal to 1 (e.g. from a log softmax) and 15 | an equal sized tensor of true counts and returns the probability of 16 | observing the true counts given the predicted probabilities under a 17 | multinomial distribution. Can accept tensors with 2 or more dimensions 18 | and averages over all except for the last axis, which is the number 19 | of categories. 20 | 21 | Adapted from Alex Tseng. 22 | 23 | Parameters 24 | ---------- 25 | logps: torch.tensor, shape=(n, ..., L) 26 | A tensor with `n` examples and `L` possible categories. 27 | 28 | true_counts: torch.tensor, shape=(n, ..., L) 29 | A tensor with `n` examples and `L` possible categories. 30 | 31 | Returns 32 | ------- 33 | loss: float 34 | The multinomial log likelihood loss of the true counts given the 35 | predicted probabilities, averaged over all examples and all other 36 | dimensions. 37 | """ 38 | 39 | log_fact_sum = torch.lgamma(torch.sum(true_counts, dim=-1) + 1) 40 | log_prod_fact = torch.sum(torch.lgamma(true_counts + 1), dim=-1) 41 | log_prod_exp = torch.sum(true_counts * logps, dim=-1) 42 | return -log_fact_sum + log_prod_fact - log_prod_exp 43 | 44 | 45 | def log1pMSELoss(log_predicted_counts, true_counts): 46 | """A MSE loss on the log(x+1) of the inputs. 47 | 48 | This loss will accept tensors of predicted counts and a vector of true 49 | counts and return the MSE on the log of the labels. The squared error 50 | is calculated for each position in the tensor and then averaged, regardless 51 | of the shape. 52 | 53 | Note: The predicted counts are in log space but the true counts are in the 54 | original count space. 55 | 56 | Parameters 57 | ---------- 58 | log_predicted_counts: torch.tensor, shape=(n, ...) 59 | A tensor of log predicted counts where the first axis is the number of 60 | examples. Important: these values are already in log space. 61 | 62 | true_counts: torch.tensor, shape=(n, ...) 63 | A tensor of the true counts where the first axis is the number of 64 | examples. 65 | 66 | Returns 67 | ------- 68 | loss: torch.tensor, shape=(n, 1) 69 | The MSE loss on the log of the two inputs, averaged over all examples 70 | and all other dimensions. 71 | """ 72 | 73 | log_true = torch.log(true_counts+1) 74 | return torch.mean(torch.square(log_true - log_predicted_counts), dim=-1) 75 | -------------------------------------------------------------------------------- /bpnetlite/marginalize.py: -------------------------------------------------------------------------------- 1 | # marginalize.py 2 | # Author: Jacob Schreiber 3 | 4 | import os 5 | import numpy 6 | import torch 7 | import pandas 8 | import seaborn 9 | 10 | from bpnetlite.bpnet import _ProfileLogitScaling 11 | from bpnetlite.bpnet import CountWrapper 12 | from bpnetlite.bpnet import ProfileWrapper 13 | from bpnetlite.chrombpnet import _Log, _Exp 14 | 15 | from tangermeme.io import read_meme 16 | from tangermeme.io import one_hot_encode 17 | 18 | from tangermeme.marginalize import marginalize 19 | from tangermeme.deep_lift_shap import deep_lift_shap 20 | from tangermeme.deep_lift_shap import _nonlinear 21 | 22 | from tangermeme.plot import plot_logo 23 | 24 | from tqdm import tqdm 25 | 26 | import matplotlib.pyplot as plt 27 | 28 | 29 | def path_to_image_html(path): 30 | return '' 31 | 32 | 33 | def _plot_profiles(y, ylim, color, path, figsize=(10,3)): 34 | plt.figure(figsize=figsize) 35 | plt.plot(y, color=color) 36 | 37 | seaborn.despine() 38 | plt.xlim(0, y.shape[0]) 39 | plt.xlabel("Relative Coordinate", fontsize=12) 40 | plt.ylabel("Predicted Signal", fontsize=12) 41 | plt.ylim(*ylim) 42 | plt.yticks(fontsize=12) 43 | plt.savefig(path) 44 | plt.close() 45 | 46 | 47 | def _plot_counts(y_before, y_after, xlim, ylim, color, path, figsize=(10,3)): 48 | zmax = max(xlim[1], ylim[1]) 49 | zmin = min(xlim[0], ylim[0]) 50 | 51 | plt.figure(figsize=figsize) 52 | plt.plot([zmin, zmax], [zmin, zmax], color='0.5') 53 | plt.scatter(y_before, y_after, color=color, s=5) 54 | 55 | seaborn.despine() 56 | plt.xlabel("Pred Counts Before", fontsize=12) 57 | plt.ylabel("Pred Counts After", fontsize=12) 58 | plt.xlim(*xlim) 59 | plt.ylim(*ylim) 60 | plt.yticks(fontsize=12) 61 | plt.savefig(path) 62 | plt.close() 63 | 64 | 65 | def _plot_attributions(y, ylim, path, figsize=(10,3), **kwargs): 66 | fig = plt.figure(figsize=figsize) 67 | ax = fig.add_subplot(111) 68 | 69 | plot_logo(y.T, ax=ax) 70 | 71 | plt.yticks(fontsize=12) 72 | plt.ylim(*ylim) 73 | plt.savefig(path) 74 | plt.close() 75 | 76 | 77 | def marginalization_report(model, motifs, X, output_dir, batch_size=64, 78 | attributions=True, minimal=False, verbose=False): 79 | """Create an HTML report showing the impact of each motif. 80 | 81 | Take in a predictive model, a MEME file of motifs, and a set of sequences, 82 | and run a marginalization experiment on each motif. Store the images 83 | to the output directory, and then create an HTML report that puts together 84 | these images. 85 | 86 | 87 | Parameters 88 | ---------- 89 | model: bpnetlite.bpnet.BPNet or bpnetlite.chrombpnet.ChromBPNet 90 | A BPNet- or ChromBPNet-style model that outputs predictions for a 91 | profile head and a count head. 92 | 93 | motif: str 94 | A motif to insert into the middle of the background sequences. 95 | 96 | X: torch.Tensor, shape=(n, 4, 2114) 97 | A one-hot encoded set of n sequences to run through the model. 98 | 99 | output_dir: str 100 | The folder name to put all the images that are generated. 101 | 102 | batch_size: int, optional 103 | The number of examples to run at a time in each batch. Default is 64. 104 | 105 | attributions: bool, optional 106 | Whether to calculate attributions as well as calculating predictions 107 | when doing marginalizations. Because calculating attributions is the 108 | most time-intensive aspect, setting this parameter to False can save 109 | time. Default is True 110 | 111 | minimal: bool, optional 112 | Whether to produce a minimal report, which shows the differences in 113 | outputs, or the full report, which shows the results before and after 114 | insertion as well as the differences. Potentially useful for debugging. 115 | Default is False. 116 | 117 | verbose: bool, optional 118 | Whether to print a progress bar as motifs are marginalized over. Default 119 | is False. 120 | """ 121 | 122 | motifs = list(read_meme(motifs).items()) 123 | 124 | if not os.path.isdir(output_dir): 125 | os.mkdir(output_dir) 126 | 127 | results = { 128 | 'name': [], 129 | 'sequence': [], 130 | 'profile (before)': [], 131 | 'profile (after)': [], 132 | 'profile (diff)': [], 133 | 'counts': [], 134 | 'profile attributions (before)': [], 135 | 'profile attributions (after)': [], 136 | 'profile attributions (diff)': [], 137 | 'counts attributions (before)': [], 138 | 'counts attributions (after)': [], 139 | 'counts attributions (diff)': [] 140 | } 141 | 142 | 143 | p, c, a = 'profile', 'counts', ' attributions' 144 | pb, pa, pd = '.profile.before', '.profile.after', '.profile.diff' 145 | pab, paa, pad = '.pattr.before', '.pattr.after', '.pattr.diff' 146 | cab, caa, cad = '.cattr.before', '.cattr.after', '.cattr.diff' 147 | 148 | mid, w = X.shape[-1] // 2, 15 149 | s, e = mid - w, mid + w 150 | 151 | p_before, p_after, p_diff = [], [], [] 152 | c_before, c_after, c_diff = [], [], [] 153 | ap_before, ap_after, ap_diff = [], [], [] 154 | ac_before, ac_after, ac_diff = [], [], [] 155 | 156 | for i, (name, pwm) in tqdm(enumerate(motifs), disable=not verbose): 157 | motif = ''.join(numpy.array(['A', 'C', 'G', 'T'])[pwm.argmax(axis=0)]) 158 | 159 | (y_profile_before, y_counts_before), (y_profile_after, 160 | y_counts_after) = marginalize(model, X, motif, 161 | batch_size=batch_size) 162 | 163 | y_profile_before = torch.nn.functional.softmax(y_profile_before, dim=-1) 164 | y_profile_after = torch.nn.functional.softmax(y_profile_after, dim=-1) 165 | 166 | p_before.append(y_profile_before.mean(axis=0).T) 167 | p_after.append(y_profile_after.mean(axis=0).T) 168 | p_diff_ = (y_profile_after - y_profile_before).mean(axis=0).T 169 | p_diff.append(p_diff_) 170 | 171 | c_before.append(y_counts_before) 172 | c_after.append(y_counts_after) 173 | 174 | if attributions: 175 | Xp_attr_before, Xp_attr_after = marginalize(ProfileWrapper( 176 | model.double()), X.double(), motif, func=deep_lift_shap, 177 | additional_nonlinear_ops={ 178 | _ProfileLogitScaling: _nonlinear, 179 | _Log: _nonlinear, 180 | _Exp: _nonlinear 181 | }, 182 | n_shuffles=1, batch_size=batch_size) 183 | 184 | Xc_attr_before, Xc_attr_after = marginalize( 185 | CountWrapper(model).double(), X.double(), motif, 186 | func=deep_lift_shap, n_shuffles=1, 187 | additional_nonlinear_ops={ 188 | _ProfileLogitScaling: _nonlinear, 189 | _Log: _nonlinear, 190 | _Exp: _nonlinear 191 | }, 192 | batch_size=batch_size) 193 | 194 | ap_before.append(Xp_attr_before.mean(axis=0)[:, s:e].T) 195 | ap_after.append(Xp_attr_after.mean(axis=0)[:, s:e].T) 196 | ap_diff_ = (Xp_attr_after - Xp_attr_before).mean(axis=0)[:, s:e].T 197 | ap_diff.append(ap_diff_) 198 | 199 | ac_before.append(Xc_attr_before.mean(axis=0)[:, s:e].T) 200 | ac_after.append(Xc_attr_after.mean(axis=0)[:, s:e].T) 201 | ac_diff_ = (Xc_attr_after - Xc_attr_before).mean(axis=0)[:, s:e].T 202 | ac_diff.append(ac_diff_) 203 | 204 | multistack = lambda args: [torch.stack(a).numpy() for a in args] 205 | p_before, p_after, p_diff = multistack([p_before, p_after, p_diff]) 206 | c_before, c_after = multistack([c_before, c_after]) 207 | 208 | multilims = lambda args: [(x.min() * 0.95, x.max() * 1.05) for x in args] 209 | pb_lim, pa_lim, pd_lim = multilims([p_before, p_after, p_diff]) 210 | cb_lim, ca_lim = multilims([c_before, c_after]) 211 | 212 | if attributions: 213 | apb, apa, apd = multistack([ap_before, ap_after, ap_diff]) 214 | acb, aca, acd = multistack([ac_before, ac_after, ac_diff]) 215 | 216 | apb_lim, apa_lim, apd_lim = multilims([apb, apa, apd]) 217 | acb_lim, aca_lim, acd_lim = multilims([acb, aca, acd]) 218 | 219 | 220 | idxs = (c_after - c_before).mean(axis=1)[:, 0].argsort()[::-1] 221 | for i, idx in enumerate(idxs): 222 | name, pwm = motifs[idx] 223 | oname = output_dir + name 224 | motif = ''.join(numpy.array(['A', 'C', 'G', 'T'])[pwm.argmax(axis=0)]) 225 | 226 | if not minimal: 227 | _plot_profiles(p_before[idx], pb_lim, color='0.5', 228 | path=oname + pb + ".png") 229 | _plot_profiles(p_after[idx], pa_lim, color='0.5', 230 | path=oname + pa + ".png") 231 | 232 | _plot_profiles(p_diff[idx], pd_lim, color='c', 233 | path=oname + pd + ".png") 234 | _plot_counts(c_before[idx], c_after[idx], cb_lim, ca_lim, color='m', 235 | path=oname + ".counts.png") 236 | 237 | if attributions: 238 | if not minimal: 239 | _plot_attributions(apb[idx], apb_lim, oname + pab + ".png") 240 | _plot_attributions(apa[idx], apa_lim, oname + paa + ".png") 241 | 242 | _plot_attributions(acb[idx], acb_lim, oname + cab + ".png") 243 | _plot_attributions(aca[idx], aca_lim, oname + caa + ".png") 244 | 245 | _plot_attributions(apd[idx], apd_lim, oname + pad + ".png") 246 | _plot_attributions(acd[idx], acd_lim, oname + cad + ".png") 247 | 248 | motif_ = motif[:25] + ('...' if len(motif) > 25 else '') 249 | 250 | results['name'].append(name) 251 | results['sequence'].append(motif_) 252 | results[p + ' (before)'].append(oname + pb + ".png") 253 | results[p + ' (after)'].append(oname + pa + ".png") 254 | results[p + ' (diff)'].append(oname + pd + ".png") 255 | results[c].append(oname + ".counts.png") 256 | 257 | if attributions: 258 | results[p + a + ' (before)'].append(oname + pab + ".png") 259 | results[p + a + ' (after)'].append(oname + paa + ".png") 260 | results[p + a + ' (diff)'].append(oname + pad + ".png") 261 | results[c + a + ' (before)'].append(oname + cab + ".png") 262 | results[c + a + ' (after)'].append(oname + caa + ".png") 263 | results[c + a + ' (diff)'].append(oname + cad + ".png") 264 | 265 | if not attributions: 266 | for key in results.keys(): 267 | if 'attributions' in key: 268 | del results[key] 269 | 270 | formatters = {name: path_to_image_html for name in results.keys() 271 | if name not in ('name', 'sequence')} 272 | 273 | results_df = pandas.DataFrame(results) 274 | if minimal: 275 | results_df = results_df[['name', 'sequence', 'profile (diff)', 276 | 'counts', 'profile attributions (diff)', 277 | 'counts attributions (diff)']] 278 | 279 | results_df.to_html(open('{}/marginalization.html'.format(output_dir), 'w'), 280 | escape=False, formatters=formatters, index=False) 281 | -------------------------------------------------------------------------------- /bpnetlite/performance.py: -------------------------------------------------------------------------------- 1 | # performance.py 2 | # Authors: Alex Tseng 3 | # Jacob Schreiber 4 | 5 | """ 6 | This module contains performance measures that are used to evaluate the 7 | model but are not explicitly used as losses for optimization. 8 | 9 | IMPORTANT: MANY OF THESE FUNCTIONS ASSUME THE INPUTS TO BE PREDICTED LOG 10 | PROBABILITIES AND TRUE COUNTS. THE FIRST ARGUMENT MUST BE IN LOG SPACE 11 | AND THE SECOND ARGUMENT MUST BE IN COUNT SPACE FOR THESE FUNCTIONS. 12 | """ 13 | 14 | import torch 15 | 16 | from .losses import MNLLLoss 17 | from .losses import log1pMSELoss 18 | 19 | 20 | def smooth_gaussian1d(x, kernel_sigma, kernel_width): 21 | """Smooth a signal along the sequence length axis. 22 | 23 | This function is a replacement for the scipy.ndimage.gaussian1d 24 | function that works on PyTorch tensors. It applies a Gaussian kernel 25 | to each position which is equivalent to applying a convolution across 26 | the sequence with weights equal to that of a Gaussian distribution. 27 | Each sequence, and each channel within the sequence, is smoothed 28 | independently. 29 | 30 | Parameters 31 | ---------- 32 | x: torch.tensor, shape=(n_sequences, n_channels, seq_len) 33 | A tensor to smooth along the last axis. n_channels must be at 34 | least 1. 35 | 36 | kernel_sigma: float 37 | The standard deviation of the Gaussian to be applied. 38 | 39 | kernel_width: int 40 | The width of the kernel to be applied. 41 | 42 | Returns 43 | ------- 44 | x_smooth: torch.tensor, shape=(n_sequences, n_channels, seq_len) 45 | The smoothed tensor. 46 | """ 47 | 48 | meshgrid = torch.arange(kernel_width, dtype=torch.float32, 49 | device=x.device) 50 | 51 | mean = (kernel_width - 1.) / 2. 52 | kernel = torch.exp(-0.5 * ((meshgrid - mean) / kernel_sigma) ** 2.0) 53 | kernel = kernel / torch.sum(kernel) 54 | kernel = kernel.reshape(1, 1, kernel_width).repeat(x.shape[1], 1, 1) 55 | return torch.nn.functional.conv1d(x, weight=kernel, groups=x.shape[1], 56 | padding='same') 57 | 58 | def batched_smoothed_function(logps, true_counts, f, smooth_predictions=False, 59 | smooth_true=False, kernel_sigma=7, kernel_width=81, 60 | exponentiate_logps=False, batch_size=200): 61 | """Batch a calculation with optional smoothing. 62 | 63 | Given a set of predicted and true values, apply some function to them in 64 | a batched manner and store the results. Optionally, either the true values 65 | or the predicted ones can be smoothed. 66 | 67 | Parameters 68 | ---------- 69 | logps: torch.tensor 70 | A tensor of the predicted log probability values. 71 | 72 | true_counts: torch.tensor 73 | A tensor of the true values, usually integer counts. 74 | 75 | f: function 76 | A function to be applied to the predicted and true values. 77 | 78 | smooth_predictions: bool, optional 79 | Whether to apply a Gaussian filter to the predictions. Default is 80 | False. 81 | 82 | smooth_true: bool, optional 83 | Whether to apply a Gaussian filter to the true values. Default is 84 | False. 85 | 86 | kernel_sigma: float, optional 87 | The standard deviation of the Gaussian to be applied. Default is 7. 88 | 89 | kernel_width: int, optional 90 | The width of the kernel to be applied. Default is 81. 91 | 92 | exponentiate_logps: bool, optional 93 | Whether to exponentiate each batch of log probabilities. Default is 94 | False. 95 | 96 | batch_size: int, optional 97 | The number of examples in each batch to evaluate at a time. Default 98 | is 200. 99 | 100 | 101 | Returns 102 | ------- 103 | results: torch.tensor 104 | The results of applying the function to the tensor. 105 | """ 106 | 107 | n = logps.shape[0] 108 | results = torch.empty(*logps.shape[:2]) 109 | 110 | for start in range(0, n, batch_size): 111 | end = start + batch_size 112 | 113 | logps_ = logps[start:end] 114 | true_counts_ = true_counts[start:end] 115 | 116 | if smooth_predictions: 117 | logps_ = torch.exp(logps_) 118 | logps_ = smooth_gaussian1d(logps_, kernel_sigma, kernel_width) 119 | 120 | if exponentiate_logps == False: 121 | logps_ = torch.log(logps_) 122 | else: 123 | if exponentiate_logps: 124 | logps_ = torch.exp(logps_) 125 | 126 | if smooth_true: 127 | true_counts_ = smooth_gaussian1d(true_counts_, kernel_sigma, kernel_width) 128 | 129 | results[start:end] = f(logps_, true_counts_) 130 | 131 | return results 132 | 133 | def _kl_divergence(probs1, probs2): 134 | """ 135 | Computes the KL divergence in the last dimension of `probs1` and `probs2` 136 | as KL(P1 || P2). `probs1` and `probs2` must be the same shape. For example, 137 | if they are both A x B x L arrays, then the KL divergence of corresponding 138 | L-arrays will be computed and returned in an A x B array. Does not 139 | renormalize the arrays. If probs2[i] is 0, that value contributes 0. 140 | """ 141 | 142 | idxs = ((probs1 != 0) & (probs2 != 0)) 143 | quot_ = torch.divide(probs1, probs2) 144 | 145 | quot = torch.ones_like(probs1) 146 | quot[idxs] = quot_[idxs] 147 | return torch.sum(probs1 * torch.log(quot), dim=-1) 148 | 149 | def jensen_shannon_distance(logps, true_counts): 150 | """ 151 | Computes the Jesnsen-Shannon distance in the last dimension of `probs1` and 152 | `probs2`. `probs1` and `probs2` must be the same shape. For example, if they 153 | are both A x B x L arrays, then the KL divergence of corresponding L-arrays 154 | will be computed and returned in an A x B array. This will renormalize the 155 | arrays so that each subarray sums to 1. If the sum of a subarray is 0, then 156 | the resulting JSD will be NaN. 157 | """ 158 | # Renormalize both distributions, and if the sum is NaN, put NaNs all around 159 | 160 | probs1 = torch.exp(logps) 161 | probs1_sum = torch.sum(probs1, dim=-1, keepdims=True) 162 | probs1 = torch.divide(probs1, probs1_sum, out=torch.zeros_like(probs1)) 163 | 164 | probs2_sum = torch.sum(true_counts, dim=-1, keepdims=True) 165 | probs2 = torch.divide(true_counts, probs2_sum, out=torch.zeros_like(true_counts)) 166 | 167 | mid = 0.5 * (probs1 + probs2) 168 | return 0.5 * (_kl_divergence(probs1, mid) + _kl_divergence(probs2, mid)) 169 | 170 | def pearson_corr(arr1, arr2): 171 | """The Pearson correlation between two tensors across the last axis. 172 | 173 | Computes the Pearson correlation in the last dimension of `arr1` and `arr2`. 174 | `arr1` and `arr2` must be the same shape. For example, if they are both 175 | A x B x L arrays, then the correlation of corresponding L-arrays will be 176 | computed and returned in an A x B array. 177 | 178 | Parameters 179 | ---------- 180 | arr1: torch.tensor 181 | One of the tensor to correlate. 182 | 183 | arr2: torch.tensor 184 | The other tensor to correlation. 185 | 186 | Returns 187 | ------- 188 | correlation: torch.tensor 189 | The correlation for each element, calculated along the last axis. 190 | """ 191 | 192 | mean1 = torch.mean(arr1, axis=-1).unsqueeze(-1) 193 | mean2 = torch.mean(arr2, axis=-1).unsqueeze(-1) 194 | dev1, dev2 = arr1 - mean1, arr2 - mean2 195 | 196 | sqdev1, sqdev2 = torch.square(dev1), torch.square(dev2) 197 | numer = torch.sum(dev1 * dev2, axis=-1) # Covariance 198 | var1, var2 = torch.sum(sqdev1, axis=-1), torch.sum(sqdev2, axis=-1) # Variances 199 | denom = torch.sqrt(var1 * var2) 200 | 201 | # Divide numerator by denominator, but use 0 where the denominator is 0 202 | correlation = torch.zeros_like(numer) 203 | correlation[denom != 0] = numer[denom != 0] / denom[denom != 0] 204 | return correlation 205 | 206 | 207 | def spearman_corr(arr1, arr2): 208 | """The Spearman correlation between two tensors across the last axis. 209 | 210 | Computes the Spearman correlation in the last dimension of `arr1` and `arr2`. 211 | `arr1` and `arr2` must be the same shape. For example, if they are both 212 | A x B x L arrays, then the correlation of corresponding L-arrays will be 213 | computed and returned in an A x B array. 214 | 215 | A dense ordering is used and ties are broken based on position in the 216 | tensor. 217 | 218 | Parameters 219 | ---------- 220 | arr1: torch.tensor 221 | One of the tensor to correlate. 222 | 223 | arr2: torch.tensor 224 | The other tensor to correlation. 225 | 226 | Returns 227 | ------- 228 | correlation: torch.tensor 229 | The correlation for each element, calculated along the last axis. 230 | """ 231 | 232 | ranks1 = arr1.argsort().argsort().type(torch.float32) 233 | ranks2 = arr2.argsort().argsort().type(torch.float32) 234 | return pearson_corr(ranks1, ranks2) 235 | 236 | 237 | def mean_squared_error(arr1, arr2): 238 | """The mean squared error between two tensors averaged along the last axis. 239 | 240 | Computes the element-wise squared error between two tensors and averages 241 | these across the last dimension. `arr1` and `arr2` must be the same shape. 242 | For example, if they are both A x B x L arrays, then the correlation of 243 | corresponding L-arrays will be computed and returned in an A x B array. 244 | 245 | Parameters 246 | ---------- 247 | arr1: torch.tensor 248 | A tensor of values. 249 | 250 | arr2: torch.tensor 251 | Another tensor of values. 252 | 253 | Returns 254 | ------- 255 | mse: torch.tensor 256 | The L2 distance between two tensors. 257 | """ 258 | 259 | return torch.mean(torch.square(arr1 - arr2), axis=-1) 260 | 261 | def calculate_performance_measures(logps, true_counts, pred_log_counts, 262 | kernel_sigma=7, kernel_width=81, smooth_true=False, 263 | smooth_predictions=False, measures=None): 264 | """ 265 | Computes some evaluation metrics on a set of positive examples, given the 266 | predicted profiles/counts, and the true profiles/counts. 267 | Arguments: 268 | `true_profs`: N x T x O x 2 array, where N is the number of 269 | examples, T is the number of tasks, and O is the output profile 270 | length; contains the true profiles for each for each task and 271 | strand, as RAW counts 272 | `log_pred_profs`: a N x T x O x 2 array, containing the predicted 273 | profiles for each task and strand, as LOG probabilities 274 | `true_counts`: a N x T x 2 array, containing the true total counts 275 | for each task and strand 276 | `log_pred_counts`: a N x T x 2 array, containing the predicted LOG total 277 | counts for each task and strand 278 | `smooth_true_profs`: if True, smooth the true profiles before computing 279 | JSD and correlations; true profiles will not be smoothed for any 280 | other metric 281 | `smooth_pred_profs`: if True, smooth the predicted profiles before 282 | computing NLL, cross entropy, JSD, and correlations; predicted 283 | profiles will not be smoothed for any other metric 284 | `print_updates`: if True, print out updates and runtimes 285 | Returns a dictionary with the following: 286 | A N x T-array of the average negative log likelihoods for the profiles 287 | (given predicted probabilities, the likelihood for the true counts), 288 | for each sample/task (strands averaged) 289 | A N x T-array of the average cross entropy for the profiles (given 290 | predicted probabilities, the likelihood for the true counts), for 291 | each sample/task (strands averaged) 292 | A N x T array of average Jensen-Shannon divergence between the predicted 293 | and true profiles (strands averaged) 294 | A N x T array of the Pearson correlation of the predicted and true (log) 295 | counts, for each sample/task (strands pooled) 296 | A N x T array of the Spearman correlation of the predicted and true 297 | (log) counts, for each sample/task (strands pooled) 298 | A N x T array of the mean squared error of the predicted and true (log) 299 | counts, for each sample/task (strands pooled) 300 | A T-array of the Pearson correlation of the (log) total counts, over all 301 | strands and samples 302 | A T-array of the Spearman correlation of the (log) total counts, over 303 | all strands and samples 304 | A T-array of the mean squared error of the (log) total counts, over all 305 | strands and samples 306 | """ 307 | 308 | measures_ = {} 309 | 310 | if measures is None or 'profile_mnll' in measures: 311 | measures_['profile_mnll'] = batched_smoothed_function(logps=logps, 312 | true_counts=true_counts, f=MNLLLoss, 313 | smooth_predictions=smooth_predictions, smooth_true=False, 314 | kernel_sigma=kernel_sigma, kernel_width=kernel_width) 315 | 316 | if measures is None or 'profile_jsd' in measures: 317 | measures_['profile_jsd'] = batched_smoothed_function(logps=logps, 318 | true_counts=true_counts, f=jensen_shannon_distance, 319 | smooth_predictions=smooth_predictions, smooth_true=smooth_true, 320 | kernel_sigma=kernel_sigma, kernel_width=kernel_width) 321 | 322 | if measures is None or 'profile_pearson' in measures: 323 | measures_['profile_pearson'] = batched_smoothed_function(logps=logps, 324 | true_counts=true_counts, f=pearson_corr, 325 | smooth_predictions=smooth_predictions, smooth_true=smooth_true, 326 | exponentiate_logps=True, kernel_sigma=kernel_sigma, 327 | kernel_width=kernel_width) 328 | 329 | if measures is None or 'profile_spearman' in measures: 330 | measures_['profile_spearman'] = batched_smoothed_function(logps=logps, 331 | true_counts=true_counts, f=spearman_corr, 332 | smooth_predictions=smooth_predictions, smooth_true=smooth_true, 333 | exponentiate_logps=True, kernel_sigma=kernel_sigma, 334 | kernel_width=kernel_width) 335 | 336 | 337 | # Total count correlations/MSE 338 | true_log_counts = torch.log(true_counts.sum(dim=-1)+1) 339 | 340 | if measures is None or 'count_pearson' in measures: 341 | measures_['count_pearson'] = pearson_corr(pred_log_counts.T, 342 | true_log_counts.T) 343 | 344 | if measures is None or 'count_spearman' in measures: 345 | measures_['count_spearman'] = spearman_corr(pred_log_counts.T, 346 | true_log_counts.T) 347 | 348 | if measures is None or 'count_mse' in measures: 349 | measures_['count_mse'] = mean_squared_error(pred_log_counts.T, 350 | true_log_counts.T) 351 | 352 | return measures_ -------------------------------------------------------------------------------- /chrombpnet: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # ChromBPNet command-line tool 3 | # Author: Jacob Schreiber 4 | 5 | import os 6 | import sys 7 | import numpy 8 | import torch 9 | import argparse 10 | import subprocess 11 | 12 | from bpnetlite.bpnet import BPNet 13 | from bpnetlite.chrombpnet import ChromBPNet 14 | 15 | from bpnetlite.io import PeakGenerator 16 | from bpnetlite.marginalize import marginalization_report 17 | 18 | from tangermeme.io import extract_loci 19 | 20 | import json 21 | 22 | desc = """ChromBPNet is a neural network that builds off the original BPNet 23 | architecture by explicitly learning bias in the signal tracks themselves. 24 | Specifically, for ATAC-seq and DNAse-seq experiments, the cutting enzymes 25 | have a soft sequence bias (though this is much stronger for Tn5, the 26 | enzyme for ATAC-seq). Accordingly, ChromBPNet is a pair of neural networks 27 | where one models the bias explicitly and one models the accessibility 28 | explicitly. This tool provides functionality for training the combination 29 | of the bias model and accessibility model and making predictions using it. 30 | After training, the accessibility model can be used using the `bpnet` 31 | tool.""" 32 | 33 | _help = """Must be either 'fit', 'predict', 'attribute', 'marginalize', 34 | or 'pipeline'.""" 35 | 36 | 37 | # Read in the arguments 38 | parser = argparse.ArgumentParser(description=desc) 39 | subparsers = parser.add_subparsers(help=_help, required=True, dest='cmd') 40 | 41 | train_parser = subparsers.add_parser("fit", help="Fit a ChromBPNet model.") 42 | train_parser.add_argument("-p", "--parameters", type=str, required=True, 43 | help="A JSON file containing the parameters for fitting the model.") 44 | 45 | predict_parser = subparsers.add_parser("predict", 46 | help="Make predictions using a trained ChromBPNet model.") 47 | predict_parser.add_argument("-p", "--parameters", type=str, required=True, 48 | help="A JSON file containing the parameters for making predictions.") 49 | 50 | attribute_parser = subparsers.add_parser("attribute", 51 | help="Calculate attributions using a trained ChromBPNet model.") 52 | attribute_parser.add_argument("-p", "--parameters", type=str, required=True, 53 | help="A JSON file containing the parameters for calculating attributions.") 54 | 55 | marginalize_parser = subparsers.add_parser("marginalize", 56 | help="Run marginalizations given motifs.") 57 | marginalize_parser.add_argument("-p", "--parameters", type=str, required=True, 58 | help="A JSON file containing the parameters for calculating attributions.") 59 | 60 | pipeline_parser = subparsers.add_parser("pipeline", 61 | help="Run each step on the given files.") 62 | pipeline_parser.add_argument("-p", "--parameters", type=str, required=True, 63 | help="A JSON file containing the parameters used for each step.") 64 | 65 | 66 | ### 67 | # Default Parameters 68 | ### 69 | 70 | default_fit_parameters = { 71 | 'n_filters': 256, 72 | 'n_layers': 8, 73 | 'profile_output_bias': True, 74 | 'count_output_bias': True, 75 | 'name': None, 76 | 'batch_size': 64, 77 | 'in_window': 2114, 78 | 'out_window': 1000, 79 | 'max_jitter': 128, 80 | 'reverse_complement': True, 81 | 'max_epochs': 50, 82 | 'validation_iter': 100, 83 | 'lr': 0.001, 84 | 'alpha': 10, 85 | 'beta': 0.5, 86 | 'early_stopping': None, 87 | 'verbose': False, 88 | 'bias_model': None, 89 | 90 | 'min_counts': None, 91 | 'max_counts': None, 92 | 93 | 'training_chroms': ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 94 | 'chr9', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 95 | 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'], 96 | 'validation_chroms': ['chr8', 'chr10'], 97 | 98 | 'sequences': None, 99 | 'loci': None, 100 | 'negatives': None, 101 | 'signals': None, 102 | 'random_state': None, 103 | 104 | # Fit bias model 105 | 'bias_fit_parameters': { 106 | 'n_filters': None, 107 | 'n_layers': 4, 108 | 'alpha': None, 109 | 'max_counts': None, 110 | 'loci': None, 111 | 'verbose': None, 112 | 'random_state': None 113 | } 114 | } 115 | 116 | default_predict_parameters = { 117 | 'batch_size': 64, 118 | 'in_window': 2114, 119 | 'out_window': 1000, 120 | 'verbose': False, 121 | 'chroms': ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr11', 122 | 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 123 | 'chr20', 'chr21', 'chr22', 'chrX'], 124 | 'sequences': None, 125 | 'loci': None, 126 | 'model': None, 127 | 'profile_filename': 'y_profile.npz', 128 | 'counts_filename': 'y_counts.npz' 129 | } 130 | 131 | default_attribute_parameters = { 132 | 'batch_size': 64, 133 | 'in_window': 2114, 134 | 'out_window': 1000, 135 | 'verbose': False, 136 | 'chroms': ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr11', 137 | 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 138 | 'chr20', 'chr21', 'chr22', 'chrX'], 139 | 'sequences': None, 140 | 'loci': None, 141 | 'model': None, 142 | 'output': 'counts', 143 | 'ohe_filename': 'ohe.npz', 144 | 'attr_filename': 'attr.npz', 145 | 'n_shuffles':20, 146 | 'random_state':0, 147 | 'warning_threshold':1e-4, 148 | } 149 | 150 | default_marginalize_parameters = { 151 | 'batch_size': 64, 152 | 'in_window': 2114, 153 | 'out_window': 1000, 154 | 'verbose': False, 155 | 'chroms': ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr11', 156 | 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 157 | 'chr20', 'chr21', 'chr22', 'chrX'], 158 | 'sequences': None, 159 | 'motifs': None, 160 | 'loci': None, 161 | 'n_loci': None, 162 | 'shuffle': False, 163 | 'model': None, 164 | 'output_filename':'marginalize/', 165 | 'random_state':0, 166 | 'minimal': True 167 | } 168 | 169 | default_pipeline_parameters = { 170 | # Model architecture parameters 171 | 'n_filters': 256, 172 | 'n_layers': 8, 173 | 'profile_output_bias': True, 174 | 'count_output_bias': True, 175 | 'in_window': 2114, 176 | 'out_window': 1000, 177 | 'name': None, 178 | 'model': None, 179 | 'bias_model': None, 180 | 'accessibility_model': None, 181 | 'early_stopping': None, 182 | 'verbose': False, 183 | 184 | # Data parameters 185 | 'batch_size': 64, 186 | 'max_jitter': 128, 187 | 'reverse_complement': True, 188 | 'max_epochs': 50, 189 | 'validation_iter': 100, 190 | 'lr': 0.001, 191 | 'alpha': 10, 192 | 'beta': 0.5, 193 | 'min_counts': 0, 194 | 'max_counts': 99999999, 195 | 196 | 'sequences': None, 197 | 'loci': None, 198 | 'negatives': None, 199 | 'signals': None, 200 | 201 | 'training_chroms': ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 202 | 'chr9', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 203 | 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'], 204 | 'validation_chroms': ['chr8', 'chr10'], 205 | 206 | # Fit bias model 207 | 'bias_fit_parameters': { 208 | 'n_filters': None, 209 | 'n_layers': 4, 210 | 'alpha': None, 211 | 'max_counts': None, 212 | 'loci': None, 213 | 'verbose': None, 214 | 'random_state': None 215 | }, 216 | 217 | # Fit accessibility model 218 | 'chrombpnet_fit_parameters': { 219 | 'batch_size': 64, 220 | 'sequences': None, 221 | 'loci': None, 222 | 'signals': None, 223 | 'verbose': None, 224 | 'random_state': None, 225 | }, 226 | 227 | # Predict parameters 228 | 'predict_parameters': { 229 | 'batch_size': 64, 230 | 'chroms': ['chr8', 'chr10'], 231 | 'profile_filename': None, 232 | 'counts_filename': None, 233 | 'sequences': None, 234 | 'loci': None, 235 | 'signals': None, 236 | 'verbose': None, 237 | }, 238 | 239 | 240 | # attribute parameters 241 | 'attribute_parameters': { 242 | 'batch_size': 64, 243 | 'chroms': ['chr8', 'chr10'], 244 | 'output': 'counts', 245 | 'loci': None, 246 | 'ohe_filename': None, 247 | 'attr_filename': None, 248 | 'n_shuffles': None, 249 | 'warning_threshold':1e-4, 250 | 'verbose': None, 251 | 'random_state': None 252 | }, 253 | 254 | # Modisco parameters 255 | 'modisco_motifs_parameters': { 256 | 'n_seqlets': 100000, 257 | 'output_filename': None, 258 | 'verbose': None 259 | }, 260 | 261 | # Modisco report parameters 262 | 'modisco_report_parameters': { 263 | 'motifs': None, 264 | 'output_folder': None, 265 | 'verbose': None 266 | }, 267 | 268 | # Marginalization parameters 269 | 'marginalize_parameters': { 270 | 'loci': None, 271 | 'n_loci': 100, 272 | 'shuffle': False, 273 | 'output_folder': None, 274 | 'motifs': None, 275 | 'minimal': True, 276 | 'verbose': None, 277 | 'random_state': None 278 | } 279 | } 280 | 281 | 282 | ### 283 | # Commands 284 | ### 285 | 286 | 287 | def merge_parameters(parameters, default_parameters): 288 | """Merge the provided parameters with the default parameters. 289 | 290 | 291 | Parameters 292 | ---------- 293 | parameters: str 294 | Name of the JSON folder with the provided parameters 295 | 296 | default_parameters: dict 297 | The default parameters for the operation. 298 | 299 | 300 | Returns 301 | ------- 302 | params: dict 303 | The merged set of parameters. 304 | """ 305 | 306 | with open(parameters, "r") as infile: 307 | parameters = json.load(infile) 308 | 309 | optional = ['bias_model', 'min_counts', 'max_counts', 'early_stopping'] 310 | 311 | for parameter, value in default_parameters.items(): 312 | if parameter not in parameters: 313 | if value is None and parameter not in optional: 314 | raise ValueError("Must provide value for '{}'".format(parameter)) 315 | 316 | parameters[parameter] = value 317 | 318 | return parameters 319 | 320 | 321 | # Pull the arguments 322 | args = parser.parse_args() 323 | 324 | ########## 325 | # FIT 326 | ########## 327 | 328 | if args.cmd == "fit": 329 | parameters = merge_parameters(args.parameters, default_fit_parameters) 330 | 331 | ### 332 | 333 | training_data = PeakGenerator( 334 | loci=parameters['loci'], 335 | sequences=parameters['sequences'], 336 | signals=parameters['signals'], 337 | controls=None, 338 | chroms=parameters['training_chroms'], 339 | in_window=parameters['in_window'], 340 | out_window=parameters['out_window'], 341 | max_jitter=parameters['max_jitter'], 342 | reverse_complement=parameters['reverse_complement'], 343 | min_counts=parameters['min_counts'], 344 | max_counts=parameters['max_counts'], 345 | random_state=parameters['random_state'], 346 | batch_size=parameters['batch_size'], 347 | verbose=parameters['verbose'] 348 | ) 349 | 350 | trimming = (parameters['in_window'] - parameters['out_window']) // 2 351 | 352 | if parameters['bias_model'] is None: 353 | counts = training_data.dataset.signals.sum(dim=(1, 2)) 354 | min_counts = torch.quantile(counts, 0.01).item() 355 | 356 | name = 'bias_fit_parameters' 357 | bias_fit_parameters = {key: parameters[key] for key in 358 | default_fit_parameters} 359 | for parameter, value in bias_fit_parameters[name].items(): 360 | if value is not None: 361 | bias_fit_parameters[parameter] = value 362 | if parameter == 'loci' and value is None: 363 | bias_fit_parameters[parameter] = parameters['negatives'] 364 | 365 | del bias_fit_parameters['negatives'], bias_fit_parameters['beta'] 366 | 367 | name = '{}.chrombpnet.bias.fit.json'.format(parameters['name']) 368 | bias_fit_parameters['max_counts'] = min_counts * parameters['beta'] 369 | bias_fit_parameters['name'] = parameters['name'] + '.bias' 370 | parameters['bias_model'] = bias_fit_parameters['name'] + '.torch' 371 | 372 | with open(name, 'w') as outfile: 373 | outfile.write(json.dumps(bias_fit_parameters, sort_keys=True, 374 | indent=4)) 375 | 376 | subprocess.run(["bpnet", "fit", "-p", name], check=True) 377 | 378 | 379 | if parameters['negatives'] is not None: 380 | training_data = PeakGenerator( 381 | loci=[parameters['loci'], parameters['negatives']], 382 | sequences=parameters['sequences'], 383 | signals=parameters['signals'], 384 | controls=None, 385 | chroms=parameters['training_chroms'], 386 | in_window=parameters['in_window'], 387 | out_window=parameters['out_window'], 388 | max_jitter=parameters['max_jitter'], 389 | reverse_complement=parameters['reverse_complement'], 390 | min_counts=parameters['min_counts'], 391 | max_counts=parameters['max_counts'], 392 | random_state=parameters['random_state'], 393 | batch_size=parameters['batch_size'], 394 | verbose=parameters['verbose'] 395 | ) 396 | 397 | valid_sequences, valid_signals = extract_loci( 398 | sequences=parameters['sequences'], 399 | signals=parameters['signals'], 400 | in_signals=None, 401 | loci=parameters['loci'], 402 | chroms=parameters['validation_chroms'], 403 | in_window=parameters['in_window'], 404 | out_window=parameters['out_window'], 405 | max_jitter=0, 406 | ignore=list('QWERYUIOPSDFHJKLZXVBNM'), 407 | verbose=parameters['verbose'] 408 | ) 409 | 410 | bias = torch.load(parameters['bias_model'], weights_only=False, map_location='cpu').cuda().eval() 411 | accessibility = BPNet(n_filters=parameters['n_filters'], 412 | n_layers=parameters['n_layers'], n_control_tracks=0, n_outputs=1, 413 | alpha=parameters['alpha'], 414 | name=parameters['name'] + '.accessibility', 415 | trimming=trimming).cuda() 416 | 417 | model = ChromBPNet(bias=bias, accessibility=accessibility, 418 | name=parameters['name']) 419 | 420 | optimizer = torch.optim.AdamW(model.parameters(), lr=parameters['lr']) 421 | 422 | model.fit(training_data, optimizer, X_valid=valid_sequences, 423 | y_valid=valid_signals, max_epochs=parameters['max_epochs'], 424 | validation_iter=parameters['validation_iter'], 425 | batch_size=parameters['batch_size']) 426 | 427 | 428 | ########## 429 | # PREDICT 430 | ########## 431 | 432 | elif args.cmd == 'predict': 433 | subprocess.run(["bpnet", "predict", "-p", args.parameters], check=True) 434 | 435 | 436 | ########## 437 | # ATTRIBUTE 438 | ########## 439 | 440 | elif args.cmd == 'attribute': 441 | subprocess.run(["bpnet", "attribute", "-p", args.parameters], check=True) 442 | 443 | 444 | ########## 445 | # MARGINALIZE 446 | ########## 447 | 448 | elif args.cmd == 'marginalize': 449 | subprocess.run(["bpnet", "marginalize", "-p", args.parameters], check=True) 450 | 451 | 452 | ########## 453 | # PIPELINE 454 | ########## 455 | 456 | elif args.cmd == 'pipeline': 457 | parameters = merge_parameters(args.parameters, default_pipeline_parameters) 458 | model_name = parameters['name'] 459 | 460 | # Step 1: Fit a BPNet model to the provided data 461 | if parameters['verbose']: 462 | print("Step 1: Fitting a ChromBPNet model") 463 | 464 | if parameters['model'] is None: 465 | name = '{}.chrombpnet.fit.json'.format(parameters['name']) 466 | parameters['model'] = parameters['name'] + '.torch' 467 | 468 | fit_parameters = {key: parameters[key] for key in 469 | default_fit_parameters} 470 | for parameter, value in parameters['chrombpnet_fit_parameters'].items(): 471 | if value is not None: 472 | fit_parameters[parameter] = value 473 | 474 | for parameter, value in parameters['bias_fit_parameters'].items(): 475 | if value is not None: 476 | fit_parameters['bias_fit_parameters'][parameter] = value 477 | 478 | with open(name, 'w') as outfile: 479 | outfile.write(json.dumps(fit_parameters, sort_keys=True, indent=4)) 480 | 481 | subprocess.run(["chrombpnet", "fit", "-p", name], check=True) 482 | 483 | 484 | if parameters['bias_model'] is None: 485 | parameters['bias_model'] = model_name + '.bias.torch' 486 | 487 | if parameters['accessibility_model'] is None: 488 | parameters['accessibility_model'] = (model_name + 489 | '.accessibility.torch') 490 | 491 | del parameters['bias_fit_parameters'] 492 | del parameters['chrombpnet_fit_parameters'] 493 | 494 | # Run pipeline with ChromBPNet model 495 | name = '{}.chrombpnet.pipeline.json'.format(parameters['name']) 496 | with open(name, 'w') as outfile: 497 | outfile.write(json.dumps(parameters, sort_keys=True, indent=4)) 498 | 499 | subprocess.run(["bpnet", "pipeline", "-p", name], check=True) 500 | 501 | 502 | # Run pipeline with accessibility model 503 | name = '{}.chrombpnet.accessibility.pipeline.json'.format( 504 | model_name) 505 | 506 | parameters['model'] = parameters['accessibility_model'] 507 | parameters['name'] = model_name + '.accessibility' 508 | 509 | with open(name, 'w') as outfile: 510 | outfile.write(json.dumps(parameters, sort_keys=True, indent=4)) 511 | 512 | subprocess.run(["bpnet", "pipeline", "-p", name], check=True) 513 | 514 | 515 | # Run pipeline with bias model 516 | name = '{}.chrombpnet.bias.pipeline.json'.format(model_name) 517 | 518 | parameters['model'] = parameters['bias_model'] 519 | parameters['name'] = model_name + '.bias' 520 | 521 | with open(name, 'w') as outfile: 522 | outfile.write(json.dumps(parameters, sort_keys=True, indent=4)) 523 | 524 | subprocess.run(["bpnet", "pipeline", "-p", name], check=True) 525 | 526 | -------------------------------------------------------------------------------- /example_jsons/README.md: -------------------------------------------------------------------------------- 1 | # Example JSON Parameters 2 | 3 | BPNet and ChromBPNet have many parameters for defining the architecture, learning process, and data. Rather than passing all of these parameters in through the command-line one should pass them in through a JSON. An additional benefit of this format is that, after running the command one has a complete log of the parameters used to generate the results. 4 | 5 | Default parameters are specified in the command-line tools themselves. When a value is not provided in the parameter JSON the subcommand will fall back on those default values. Potentially, this can significantly reduce the size of the JSONs you need to specify. 6 | 7 | ### BPNet Fit Parameters 8 | 9 | ``` 10 | { 11 | "n_filters": 64, # Number of filters in the convolutions 12 | "n_layers": 8, # Number of dilated residual convolutions between the initial, and final, layers. 13 | "profile_output_bias": true, # Whether to include a bias term in the profile head 14 | "count_output_bias": true, # Whether to include a bias term in the count head 15 | "name": "example", # Name of the model, primarily used to auto-generate output names if not provided 16 | 17 | "batch_size": 64, # Batch size to use for training and validation 18 | "in_window": 2114, # Length of the input window 19 | "out_window": 1000, # Length of the output window 20 | "max_jitter": 128, # Maximum amount of jitter when generating training examples 21 | "reverse_complement": true, # Whether to randomly RC half of the training examples 22 | "max_epochs": 50, # The maximum number of epochs to train for 23 | "validation_iter": 100, # The number of batches to train on before calculating validation set performance 24 | "lr": 0.001, # Learning rate of the AdamW optimizer 25 | "alpha": 1, # Weight of the count-loss in the total loss. 26 | "verbose": false, # Whether to print out a log to the terminal during training 27 | 28 | "min_counts": 0, # Ensure that each training example has at least this number of counts 29 | "max_counts": 99999999, # Ensure that each training example has no more than this number of counts 30 | 31 | # Chromosomes to train on 32 | "training_chroms": ["chr2", "chr3", "chr4" "chr5", "chr6", "chr7", 33 | "chr9", "chr11", "chr12", "chr13", "chr14", "chr15", "chr16", "chr17", 34 | "chr18", "chr19", "chr20", "chr21", "chr22", "chrX"], 35 | 36 | # Chromosomes to validate on 37 | "validation_chroms": ["chr8", "chr10"], 38 | 39 | "sequences":"../../oak/common/hg38/hg38.fa", # FASTA file of the genome to train on 40 | "loci":"../../tfatlas/processed_data/ENCSR000BGW/peaks.bed.gz", # Loci to train and validate on (can be a list) 41 | 42 | # A list of bigWig files to extract signal from -- each element in the list corresponds to one model output 43 | "signals":[ 44 | "../../tfatlas/processed_data/ENCSR000BGW/ENCSR000BGW_plus.bigWig", 45 | "../../tfatlas/processed_data/ENCSR000BGW/ENCSR000BGW_minus.bigWig" 46 | ], 47 | 48 | # An optional list of bigWig files containing control signals 49 | "controls":[ 50 | "../../tfatlas/processed_data/ENCSR000BGW/ENCSR000BGW_control_plus.bigWig", 51 | "../../tfatlas/processed_data/ENCSR000BGW/ENCSR000BGW_control_minus.bigWig" 52 | ], 53 | "random_state": 0 # A seed to control parameter initialization and data generation 54 | } 55 | ``` 56 | 57 | ### BPNet Predict 58 | 59 | ``` 60 | { 61 | "batch_size":64, # Batch size to use for training and validation 62 | "in_window":2114, # Length of the input window 63 | "out_window":1000, # Length of the output window 64 | "verbose":true, # Whether to print out a log to the terminal during training 65 | "chroms":["chr8", "chr10"], # Chromosomes whose peaks to make predictions for 66 | "model":"bpnet.64.8.torch", # Model to use to make the predictions 67 | 68 | "sequences":"../../oak/common/hg38/hg38.fa", # FASTA file of the genome to train on 69 | "loci":"../../tfatlas/processed_data/ENCSR000BGW/peaks.bed.gz", # Loci to train and validate on (can be a list) 70 | 71 | # An optional list of bigWig files containing control signals 72 | "controls": [ 73 | "../../tfatlas/processed_data/ENCSR000BGW/ENCSR000BGW_control_plus.bigWig", 74 | "../../tfatlas/processed_data/ENCSR000BGW/ENCSR000BGW_control_minus.bigWig" 75 | ], 76 | 77 | "profile_filename": "y_profile.npz", # The name of the file to store profile predictions 78 | "counts_filename": "y_counts.npz" # The name of the file to store count predictions 79 | } 80 | ``` 81 | 82 | ### BPNet Attribute 83 | 84 | ``` 85 | { 86 | "batch_size": 64, # Batch size to use for training and validation 87 | "in_window": 2114, # Length of the input window 88 | "out_window": 1000, # Length of the output window 89 | "verbose": true, # Whether to print out a progress bar during attribution 90 | "chroms": ["chr8", "chr10"], # Chromosomes whose peaks to make predictions for 91 | "model":"bpnet.64.8.torch", # Model to use for calculating attributions 92 | 93 | "sequences":"../../oak/common/hg38/hg38.fa", # FASTA file of the genome to train on 94 | "loci":"../../tfatlas/processed_data/ENCSR000BGW/peaks.bed.gz", # Loci to attribute (can be a list) 95 | 96 | "output": "profile", # Which head to calculate attributions for 97 | "ohe_filename": "ohe.npz", # Filename to store one-hot encodings of the sequences 98 | "attr_filename": "attr.npz", # Filename to store DeepLIFT/SHAP values for the sequences 99 | "n_shuffles":20, # Number of GC-matched shuffles to perform 100 | "random_state":0 # A seed to control the shuffles 101 | } 102 | ``` 103 | 104 | ### BPNet Marginalize 105 | 106 | ``` 107 | { 108 | "batch_size": 64, # Batch size to use for training and validation 109 | "in_window": 2114, # Length of the input window 110 | "out_window": 1000, # Length of the output window 111 | "verbose": true, # Whether to print out a log as motifs are inserted 112 | "chroms": ["chr8", "chr10"], # Chromosomes whose peaks to make predictions for 113 | 114 | "sequences": "../../oak/common/hg38/hg38.fa", # FASTA file of the genome to train on 115 | "loci": "../../tfatlas/processed_data/ENCSR000BGW/gc_neg_only.bed.gz", # Loci to use as background 116 | 117 | "motifs": "motifs.meme", # MEME file of motifs to insert into the sequences 118 | "n_loci": 100, # Number of background loci to use 119 | "shuffle": false, # Whether to shuffle the loci extracted before choosing `n_loci` 120 | "model": "bpnet.64.8.torch", # Model to use to make the predictions 121 | "output_filename":"marginalize/", # Folder to store the results including the summary 122 | "random_state":0, # A seed to control the shuffling 123 | "minimal": true # Whether to produce a minimal report or report all columns 124 | } 125 | ``` 126 | 127 | ### BPNet Pipeline 128 | 129 | When running the pipeline command, JSONs for each of the steps in the pipeline will be constructed from merging the parameters in the first layer with those in the nested JSONs, e.g., in the `fit_parameters`. If provided, parameters in the nested JSON will always override those in the first layer. For example, if you want to make predictions on a different genome than one has trained on, you can pass a different `sequences` and `loci` keyword in the `predict_parameters` JSON, or if you want to use a different MEME file in the marginalization and TF-MoDISco step you can. 130 | 131 | ``` 132 | { 133 | "n_filters": 64, # Number of filters in the convolutions 134 | "n_layers": 8, # Number of dilated residual convolutions between the initial, and final, layers. 135 | "profile_output_bias": true, # Whether to include a bias term in the profile head 136 | "count_output_bias": true, # Whether to include a bias term in the count head 137 | "in_window": 2114, # Length of the input window 138 | "out_window": 1000, # Length of the output window 139 | "name": "spi1", # Name of the model, primarily used to auto-generate output names if not provided 140 | "model": "spi1.torch", # Name of the model to use for the remaining steps instead of fitting a new one 141 | "batch_size": 64, # Batch size to use for training and validation 142 | "max_jitter": 128, # Maximum amount of jitter when generating training examples 143 | "reverse_complement": true, # Whether to randomly RC half of the training examples 144 | "max_epochs": 50, # The maximum number of epochs to train for 145 | "validation_iter": 100, # The number of batches to train on before calculating validation set performance 146 | "lr": 0.001, # Learning rate of the AdamW optimizer 147 | "alpha": 10, # Weight of the count-loss in the total loss. 148 | "verbose": true, # Whether to print out a log to the terminal during training 149 | "min_counts": 0, # Ensure that each training example has at least this number of counts 150 | "max_counts": 99999999, # Ensure that each training example has no more than this number of counts 151 | "sequences": "../../../oak/common/hg38/hg38.fa", # FASTA file of the genome to train on 152 | "loci": null, # Loci to use in other steps (can be a list) 153 | 154 | # A list of bigWig files to extract signal from -- each element in the list corresponds to one model output 155 | "signals": [ 156 | "../../../tfatlas/processed_data/ENCSR000BGQ/ENCSR000BGQ_plus.bigWig", 157 | "../../../tfatlas/processed_data/ENCSR000BGQ/ENCSR000BGQ_minus.bigWig" 158 | ], 159 | 160 | # An optional list of bigWig files containing control signals 161 | "controls":[ 162 | "../../../tfatlas/processed_data/ENCSR000BGQ/ENCSR000BGQ_control_plus.bigWig", 163 | "../../../tfatlas/processed_data/ENCSR000BGQ/ENCSR000BGQ_control_minus.bigWig" 164 | ], 165 | 166 | # Parameters to pass into the fit step 167 | "fit_parameters": { 168 | "batch_size": 64, # Batch size to use for training and validation 169 | 170 | # Chromosomes to train on 171 | "training_chroms": ["chr2", "chr3", "chr4", "chr5", "chr6", "chr7", 172 | "chr9", "chr11", "chr12", "chr13", "chr14", "chr15", "chr16", 173 | "chr17", "chr18", "chr19", "chr20", "chr21", "chr22", "chrX"], 174 | 175 | # Chromosomes to validate on 176 | "validation_chroms": ["chr8", "chr10"], 177 | "sequences": null, # FASTA file of the genome to train on 178 | 179 | # Loci to train and validate on (can be a list) 180 | "loci": [ 181 | "../../../tfatlas/processed_data/ENCSR000BGQ/peaks.bed.gz", 182 | "../../../tfatlas/processed_data/ENCSR000BGQ/gc_neg_only.bed.gz" 183 | ], 184 | "signals": null, # A list of bigWig files to extract signal from 185 | "controls": null, # An optional list of bigWig files containing control signals 186 | "verbose": null, # Whether to print out a log to the terminal during training 187 | "random_state": null # A seed to control parameter initialization and data generation 188 | }, 189 | 190 | # Parameters to pass into the predict step 191 | "predict_parameters": { 192 | "batch_size": 64, # Batch size to use for training and validation 193 | "chroms": ["chr8", "chr10"], # Chromosomes to make predictions for 194 | "profile_filename": null, # The name of the file to store profile predictions 195 | "counts_filename": null, # The name of the file to store count predictions 196 | "sequences": null, # FASTA file of the genome to make predictions for 197 | 198 | # Loci to make predictions for (can be a list) 199 | "loci": "../../../tfatlas/processed_data/ENCSR000BGQ/peaks.bed.gz", 200 | 201 | "sequences": null, # A list of bigWig files to extract signal from 202 | "controls": null, # An optional list of bigWig files containing control signals 203 | "verbose": null # Whether to print out a log to the terminal during training 204 | }, 205 | 206 | # Parameters to pass into the attribute step 207 | "attribute_parameters": { 208 | "batch_size": 1, # Batch size to use for training and validation 209 | "output": "profile", # Which head to calculate attributions for 210 | "chroms": ["chr8"], # Chromosomes to make predictions for 211 | 212 | # Loci to make attributions for (can be a list) 213 | "loci": "../../../tfatlas/processed_data/ENCSR000BGQ/peaks.bed.gz", 214 | 215 | "ohe_filename": null, # Filename to store one-hot encodings of the sequences 216 | "attr_filename": null, # Filename to store DeepLIFT/SHAP values for the sequences 217 | "n_shuffles": 20, # Number of GC-matched shuffles to perform 218 | "random_state": null, # A seed to control the shuffles 219 | "verbose": null # Whether to print out a progress bar during attribution 220 | 221 | }, 222 | 223 | # Parameters to pass into the `modisco motifs` subcommand 224 | "modisco_motifs_parameters": { 225 | "n_seqlets": 100000, # Maximum number of seqlets to use 226 | "output_filename": null, # Filename to store the clustering results 227 | "verbose": null # Whether to print out a progress bar during TF-MoDISco 228 | }, 229 | 230 | # Parameters to pass into the `modisco report` subcommand 231 | "modisco_report_parameters": { 232 | "motifs": "motifs.meme", # MEME file of motifs to compare the found motifs to 233 | "output_folder": null, # Folder to store image outputs and report 234 | }, 235 | 236 | # Parameters to pass into the marginalization step 237 | "marginalize_parameters": { 238 | # Loci to make predictions for (can be a list) 239 | "loci": "../../../tfatlas/processed_data/ENCSR000BGQ/gc_neg_only.bed.gz", 240 | "n_loci": 100, # Number of background loci to use 241 | "shuffle": false, # Whether to shuffle the loci extracted before choosing `n_loci` 242 | "random_state": null, # A seed to control the shuffling 243 | "output_folder": null, # Folder to store the results including the summary 244 | "motifs": "ctcf.gata2.sp1.meme", # MEME file of motifs to insert into the sequences 245 | "minimal": true, # Whether to produce a minimal report or report all columns 246 | "verbose": null # Whether to print out a log to the terminal as motifs are inserted 247 | } 248 | } 249 | ``` 250 | 251 | ## ChromBPNet Fit 252 | 253 | ``` 254 | { 255 | "n_filters": 64, # Number of filters in the convolutions 256 | "n_layers": 8, # Number of dilated residual convolutions between the initial, and final, layers. 257 | "profile_output_bias": true, # Whether to include a bias term in the profile head 258 | "count_output_bias": true, # Whether to include a bias term in the count head 259 | "name": "atac" # Name of the model 260 | 261 | "batch_size": 64, # Batch size to use for training and validation 262 | "in_window": 2114, # Length of the input window 263 | "out_window": 1000, # Length of the output window 264 | "max_jitter": 128, # Maximum amount of jitter when generating training examples 265 | "reverse_complement": true, # Whether to randomly RC half of the training examples 266 | "max_epochs": 50, # The maximum number of epochs to train for 267 | "validation_iter": 100, # The number of batches to train on before calculating validation set performance 268 | "lr": 0.001, # Learning rate of the AdamW optimizer 269 | "alpha": 10, # Weight of the count-loss in the total loss. 270 | "beta": 0.5, # Multiplier on the minimum read count in peaks to use when training a bias model 271 | 272 | "min_counts": 0, # Ensure that each training example has at least this number of counts 273 | "max_counts": 99999999, # Ensure that each training example has no more than this number of counts 274 | 275 | # Loci to train and validate on (can be a list) 276 | "loci": "../../../chromatin-atlas/ATAC/ENCSR637XSC/preprocessing/downloads/peaks.bed.gz", 277 | 278 | # Negatives to train on 279 | "negatives": "../../../chromatin-atlas/ATAC/ENCSR637XSC/negatives_data/negatives.bed", 280 | 281 | # FASTA file of the genome to train on 282 | "sequences": "../../../oak/common/hg38/hg38.fa", 283 | 284 | # A list of bigWig files to extract signal from -- each element in the list corresponds to one model output 285 | "signals": [ 286 | "../../../chromatin-atlas/ATAC/ENCSR637XSC/preprocessing/bigWigs/ENCSR637XSC.bigWig" 287 | ], 288 | 289 | # If you've already trained a bias model, put the filepath in here 290 | "bias_model": null, 291 | 292 | # Parameters for training the bias model 293 | "bias_fit_parameters": { 294 | "alpha": null, 295 | "loci": "../../../chromatin-atlas/ATAC/ENCSR637XSC/negatives_data/negatives.bed", 296 | "max_counts": null, 297 | "n_filters": null, 298 | "n_layers": 4, 299 | "random_state": null, 300 | "verbose": null 301 | }, 302 | 303 | 304 | # Chromosomes to train on 305 | "training_chroms": ["chr2", "chr3", "chr4", "chr5", "chr6", "chr7", 306 | "chr9", "chr11", "chr12", "chr13", "chr14", "chr15", "chr16", 307 | "chr17", "chr18", "chr19", "chr20", "chr21", "chr22", "chrX"], 308 | 309 | # Chromosomes to validate on 310 | "validation_chroms": ["chr8", "chr10"], 311 | 312 | "verbose": true, # # Whether to print out a log to the terminal during training 313 | "random_state": 0 # A random state to control parameter initialization and data generation 314 | } 315 | ``` 316 | 317 | ## ChromBPNet Pipeline 318 | 319 | When running the pipeline command, JSONs for each of the steps in the pipeline will be constructed from merging the parameters in the first layer with those in the nested JSONs, e.g., in the `fit_parameters`. If provided, parameters in the nested JSON will always override those in the first layer. For example, if you want to make predictions on a different genome than one has trained on, you can pass a different `sequences` and `loci` keyword in the `predict_parameters` JSON, or if you want to use a different MEME file in the marginalization and TF-MoDISco step you can. 320 | 321 | ``` 322 | { 323 | "n_filters": 64, # Number of filters in the convolutions 324 | "n_layers": 8, # Number of dilated residual convolutions between the initial, and final, layers. 325 | "profile_output_bias": true, # Whether to include a bias term in the profile head 326 | "count_output_bias": true, # Whether to include a bias term in the count head 327 | 328 | "batch_size": 64, # Batch size to use for training and validation 329 | "in_window": 2114, # Length of the input window 330 | "out_window": 1000, # Length of the output window 331 | "max_jitter": 128, # Maximum amount of jitter when generating training examples 332 | "reverse_complement": true, # Whether to randomly RC half of the training examples 333 | "max_epochs": 50, # The maximum number of epochs to train for 334 | "validation_iter": 100, # The number of batches to train on before calculating validation set performance 335 | "lr": 0.001, # Learning rate of the AdamW optimizer 336 | "alpha": 10, # Weight of the count-loss in the total loss. 337 | "beta": 0.5, # Multiplier on the minimum read count in peaks to use when training a bias model 338 | 339 | "min_counts": 0, # Ensure that each training example has at least this number of counts 340 | "max_counts": 99999999, # Ensure that each training example has no more than this number of counts 341 | 342 | "name": "atac", # Name to use for default filenames 343 | "model": "atac.torch", # Name of the full model if already trained 344 | "bias_model": "atac.bias.torch", # Name of the bias model if already trained 345 | "accessibility_model": "atac.accessibility.torch", # Name of the accessibility model if already trained 346 | 347 | "verbose": true, # Whether to print logs to the terminal 348 | "random_state": 0, # Seed to use for the steps 349 | 350 | # Loci to train and validate on (can be a list) 351 | "loci": "../../../chromatin-atlas/ATAC/ENCSR637XSC/preprocessing/downloads/peaks.bed.gz", 352 | 353 | # Negatives to train on 354 | "negatives": "../../../chromatin-atlas/ATAC/ENCSR637XSC/negatives_data/negatives.bed", 355 | 356 | # FASTA file of the genome to train on 357 | "sequences": "../../../oak/common/hg38/hg38.fa", 358 | 359 | # A list of bigWig files to extract signal from -- each element in the list corresponds to one model output 360 | "signals": [ 361 | "../../../chromatin-atlas/ATAC/ENCSR637XSC/preprocessing/bigWigs/ENCSR637XSC.bigWig" 362 | ], 363 | 364 | # Chromosomes to train on 365 | "training_chroms": ["chr2", "chr3", "chr4", "chr5", "chr6", "chr7", 366 | "chr9", "chr11", "chr12", "chr13", "chr14", "chr15", "chr16", 367 | "chr17", "chr18", "chr19", "chr20", "chr21", "chr22", "chrX"], 368 | 369 | # Chromosomes to validate on 370 | "validation_chroms": ["chr8", "chr10"], 371 | 372 | # Parameters to use in the bias fit step 373 | "bias_fit_parameters": { 374 | "n_filters": null, # Number of filters in the convolutions 375 | "n_layers": 4, # Number of dilated residual convolutions between the initial, and final, layers. 376 | "alpha": null, # Weight of the count-loss in the total loss. 377 | "max_counts": null, # Ensure that each training example has no more than this number of counts 378 | 379 | # Loci to train and validate on (can be a list) 380 | "loci": "../../../chromatin-atlas/ATAC/ENCSR637XSC/negatives_data/negatives.bed", 381 | 382 | "verbose": null, # Whether to print logs to the terminal 383 | "random_state": null # A random state to control parameter initialization and data generation 384 | }, 385 | 386 | # Parameters to use in the fit step 387 | "fit_parameters": { 388 | "batch_size": null, # Batch size to use for training and validation 389 | "sequences": null, # FASTA file of the genome to train on 390 | "loci": null, # Loci to train and validate on (can be a list) 391 | "signals": null, # A list of bigWig files to extract signal from 392 | "verbose": null, # Whether to print logs to the terminal 393 | "random_state": null # A random state to control parameter initialization and data generation 394 | }, 395 | 396 | # Parameters to pass into the predict step 397 | "predict_parameters": { 398 | "batch_size": 64, # Batch size to use for training and validation 399 | "chroms": ["chr8", "chr10"], # Chromosomes to make predictions for 400 | "profile_filename": null, # The name of the file to store profile predictions 401 | "counts_filename": null, # The name of the file to store count predictions 402 | "sequences": null, # FASTA file of the genome to make predictions for 403 | 404 | # Loci to make predictions for (can be a list) 405 | "loci": "../../../chromatin-atlas/ATAC/ENCSR637XSC/preprocessing/downloads/peaks.bed.gz", 406 | 407 | "signals": null, # A list of bigWig files to extract signal from 408 | "controls": null, # An optional list of bigWig files containing control signals 409 | "verbose": null # Whether to print out a log to the terminal during training 410 | }, 411 | 412 | # Parameters to pass into the attribute step 413 | "attribute_parameters": { 414 | "batch_size": 1, # Batch size to use for training and validation 415 | "output": "profile", # Which head to calculate attributions for 416 | "chroms": ["chr8"], # Chromosomes to make predictions for 417 | 418 | # Loci to make attributions for (can be a list) 419 | "loci": "../../../tfatlas/processed_data/ENCSR000BGQ/peaks.bed.gz", 420 | 421 | "ohe_filename": null, # Filename to store one-hot encodings of the sequences 422 | "attr_filename": null, # Filename to store DeepLIFT/SHAP values for the sequences 423 | "n_shuffles": 20, # Number of GC-matched shuffles to perform 424 | "random_state": null, # A seed to control the shuffles 425 | "verbose": null # Whether to print out a progress bar during attribution 426 | 427 | }, 428 | 429 | # Parameters to pass into the `modisco motifs` subcommand 430 | "modisco_motifs_parameters": { 431 | "n_seqlets": 100000, # Maximum number of seqlets to use 432 | "output_filename": null, # Filename to store the clustering results 433 | "verbose": null # Whether to print out a progress bar during TF-MoDISco 434 | }, 435 | 436 | # Parameters to pass into the `modisco report` subcommand 437 | "modisco_report_parameters": { 438 | "motifs": "motifs.meme", # MEME file of motifs to compare the found motifs to 439 | "output_folder": null, # Folder to store image outputs and report 440 | }, 441 | 442 | # Parameters to pass into the marginalization step 443 | "marginalize_parameters": { 444 | # Loci to make predictions for (can be a list) 445 | "loci": "../../../chromatin-atlas/ATAC/ENCSR637XSC/negatives_data/negatives.bed", 446 | "n_loci": 100, # Number of background loci to use 447 | "shuffle": false, # Whether to shuffle the loci extracted before choosing `n_loci` 448 | "random_state": null, # A seed to control the shuffling 449 | "output_folder": null, # Folder to store the results including the summary 450 | "motifs": "ctcf.gata2.sp1.meme", # MEME file of motifs to insert into the sequences 451 | "minimal": true, # Whether to produce a minimal report or report all columns 452 | "verbose": null # Whether to print out a log to the terminal as motifs are inserted 453 | } 454 | } 455 | ``` 456 | -------------------------------------------------------------------------------- /example_jsons/bpnet_attribute_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 64, 3 | "in_window": 2114, 4 | "out_window": 1000, 5 | "verbose": true, 6 | "chroms": ["chr8", "chr10"], 7 | 8 | "sequences":"../../oak/common/hg38/hg38.fa", 9 | "loci":"../../tfatlas/processed_data/ENCSR000BGW/peaks.bed.gz", 10 | "model":"bpnet.64.8.torch", 11 | "output":"count", 12 | 13 | "output": "profile", 14 | "ohe_filename": "ohe.npz", 15 | "attr_filename": "attr.npz", 16 | "n_shuffles":20, 17 | "random_state":0 18 | } 19 | -------------------------------------------------------------------------------- /example_jsons/bpnet_fit_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_filters": 64, 3 | "n_layers": 8, 4 | "profile_output_bias": true, 5 | "count_output_bias": true, 6 | "name": "example", 7 | "batch_size": 64, 8 | "in_window": 2114, 9 | "out_window": 1000, 10 | "max_jitter": 128, 11 | "reverse_complement": true, 12 | "max_epochs": 50, 13 | "validation_iter": 100, 14 | "lr": 0.001, 15 | "alpha": 1, 16 | "verbose": false, 17 | 18 | "min_counts": 0, 19 | "max_counts": 99999999, 20 | 21 | "training_chroms": ["chr2", "chr3", "chr4", "chr5", "chr6", "chr7", 22 | "chr9", "chr11", "chr12", "chr13", "chr14", "chr15", "chr16", "chr17", 23 | "chr18", "chr19", "chr20", "chr21", "chr22", "chrX"], 24 | "validation_chroms": ["chr8", "chr10"], 25 | 26 | "sequences":"../../oak/common/hg38/hg38.fa", 27 | "loci":"../../tfatlas/processed_data/ENCSR000BGW/peaks.bed.gz", 28 | "signals":[ 29 | "../../tfatlas/processed_data/ENCSR000BGW/ENCSR000BGW_plus.bigWig", 30 | "../../tfatlas/processed_data/ENCSR000BGW/ENCSR000BGW_minus.bigWig" 31 | ], 32 | "controls":[ 33 | "../../tfatlas/processed_data/ENCSR000BGW/ENCSR000BGW_control_plus.bigWig", 34 | "../../tfatlas/processed_data/ENCSR000BGW/ENCSR000BGW_control_minus.bigWig" 35 | ], 36 | "random_state": 0 37 | } 38 | -------------------------------------------------------------------------------- /example_jsons/bpnet_marginalize_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 64, 3 | "in_window": 2114, 4 | "out_window": 1000, 5 | "verbose": true, 6 | "chroms": ["chr1", "chr2", "chr3", "chr5", "chr6", "chr7", 7 | "chr8", "chr9", "chr10", "chr12", "chr13", "chr14", "chr16", 8 | "chr18", "chr19", "chr20", "chr22"], 9 | "sequences": "../../oak/common/hg38/hg38.fa", 10 | "motifs": "motifs.meme", 11 | "loci": "../../tfatlas/processed_data/ENCSR000BGW/peaks.bed.gz", 12 | "n_loci": 100, 13 | "shuffle": false, 14 | "model": "bpnet.64.8.torch", 15 | "output_filename":"marginalize/", 16 | "random_state":0, 17 | "minimal": true 18 | } 19 | -------------------------------------------------------------------------------- /example_jsons/bpnet_pipeline_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_filters": 64, 3 | "n_layers": 8, 4 | "profile_output_bias": true, 5 | "count_output_bias": true, 6 | "in_window": 2114, 7 | "out_window": 1000, 8 | "name": "spi1", 9 | "model": "spi1.torch", 10 | "verbose": false, 11 | "batch_size": 64, 12 | "max_jitter": 128, 13 | "reverse_complement": true, 14 | "max_epochs": 5, 15 | "validation_iter": 100, 16 | "lr": 0.001, 17 | "alpha": 1, 18 | "verbose": true, 19 | "min_counts": 0, 20 | "max_counts": 99999999, 21 | "sequences": "../../../oak/common/hg38/hg38.fa", 22 | "loci": null, 23 | "signals": [ 24 | "../../../tfatlas/processed_data/ENCSR000BGQ/ENCSR000BGQ_plus.bigWig", 25 | "../../../tfatlas/processed_data/ENCSR000BGQ/ENCSR000BGQ_minus.bigWig" 26 | ], 27 | "controls":[ 28 | "../../../tfatlas/processed_data/ENCSR000BGQ/ENCSR000BGQ_control_plus.bigWig", 29 | "../../../tfatlas/processed_data/ENCSR000BGQ/ENCSR000BGQ_control_minus.bigWig" 30 | ], 31 | "fit_parameters": { 32 | "batch_size": 64, 33 | "training_chroms": ["chr2", "chr3", "chr4", "chr5", "chr6", "chr7", 34 | "chr9", "chr11", "chr12", "chr13", "chr14", "chr15", "chr16", 35 | "chr17", "chr18", "chr19", "chr20", "chr21", "chr22", "chrX"], 36 | "validation_chroms": ["chr8", "chr10"], 37 | "sequences": null, 38 | "loci": [ 39 | "../../../tfatlas/processed_data/ENCSR000BGQ/peaks.bed.gz", 40 | "../../../tfatlas/processed_data/ENCSR000BGQ/gc_neg_only.bed.gz" 41 | ], 42 | "signals": null, 43 | "controls": null, 44 | "verbose": null, 45 | "random_state": null 46 | }, 47 | "predict_parameters": { 48 | "batch_size": 64, 49 | "chroms": ["chr8", "chr10"], 50 | "profile_filename": null, 51 | "counts_filename": null, 52 | "sequences": null, 53 | "loci": "../../../tfatlas/processed_data/ENCSR000BGQ/peaks.bed.gz", 54 | "signals": null, 55 | "controls": null, 56 | "verbose": null 57 | }, 58 | "attribute_parameters": { 59 | "batch_size": 1, 60 | "output": "profile", 61 | "chroms": ["chr8"], 62 | "loci": "../../../tfatlas/processed_data/ENCSR000BGQ/peaks.bed.gz", 63 | "ohe_filename": null, 64 | "attr_filename": null, 65 | "n_shuffles": 20, 66 | "random_state": null, 67 | "verbose": null 68 | 69 | }, 70 | "modisco_motifs_parameters": { 71 | "n_seqlets": 100000, 72 | "output_filename": null, 73 | "verbose": null 74 | }, 75 | "modisco_report_parameters": { 76 | "motifs": "motifs.meme", 77 | "output_folder": null, 78 | "verbose": null 79 | }, 80 | "marginalize_parameters": { 81 | "loci": "../../../tfatlas/processed_data/ENCSR000BGQ/gc_neg_only.bed.gz", 82 | "n_loci": 100, 83 | "shuffle": false, 84 | "random_state": null, 85 | "output_folder": null, 86 | "motifs": "ctcf.gata2.sp1.meme", 87 | "minimal": true, 88 | "verbose": null 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /example_jsons/bpnet_predict_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size":64, 3 | "in_window":2114, 4 | "out_window":1000, 5 | "verbose":true, 6 | "chroms":["chr8", "chr10"], 7 | "model":"bpnet.64.8.torch", 8 | 9 | "sequences":"../../oak/common/hg38/hg38.fa", 10 | "loci":"../../tfatlas/processed_data/ENCSR000BGW/peaks.bed.gz", 11 | "controls": [ 12 | "../../tfatlas/processed_data/ENCSR000BGW/ENCSR000BGW_control_plus.bigWig", 13 | "../../tfatlas/processed_data/ENCSR000BGW/ENCSR000BGW_control_minus.bigWig" 14 | ], 15 | 16 | "profile_filename": "y_profile.npz", 17 | "counts_filename": "y_counts.npz" 18 | } -------------------------------------------------------------------------------- /example_jsons/chrombpnet_fit_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "alpha": 10, 3 | "batch_size": 64, 4 | "beta": 0.5, 5 | "bias_model": null, 6 | "controls": null, 7 | "count_output_bias": true, 8 | "in_window": 2114, 9 | "loci": "../../../chromatin-atlas/ATAC/ENCSR637XSC/preprocessing/downloads/peaks.bed.gz", 10 | "lr": 0.001, 11 | "max_counts": 99999999, 12 | "max_epochs": 50, 13 | "max_jitter": 128, 14 | "min_counts": 0, 15 | "n_filters": 64, 16 | "n_layers": 8, 17 | "name": "atac", 18 | "negatives": "../../../chromatin-atlas/ATAC/ENCSR637XSC/negatives_data/negatives.bed", 19 | "out_window": 1000, 20 | "profile_output_bias": true, 21 | "random_state": 0, 22 | "reverse_complement": true, 23 | "sequences": "../../../oak/common/hg38/hg38.fa", 24 | "signals": [ 25 | "../../../chromatin-atlas/ATAC/ENCSR637XSC/preprocessing/bigWigs/ENCSR637XSC.bigWig" 26 | ], 27 | "training_chroms": ["chr2", "chr3", "chr4" "chr5", "chr6", "chr7", 28 | "chr9", "chr11", "chr12", "chr13", "chr14", "chr15", "chr16", "chr17", 29 | "chr18", "chr19", "chr20", "chr21", "chr22", "chrX"], 30 | "validation_chroms": ["chr8", "chr10"], 31 | "validation_iter": 100, 32 | "verbose": true 33 | } -------------------------------------------------------------------------------- /example_jsons/chrombpnet_pipeline_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_filters": 256, 3 | "n_layers": 8, 4 | "profile_output_bias": true, 5 | "count_output_bias": true, 6 | "in_window": 2114, 7 | "out_window": 1000, 8 | "name": "atac", 9 | "model": "atac.torch", 10 | "bias_model": "atac.bias.torch", 11 | "accessibility_model": "atac.accessibility.torch", 12 | "verbose": true, 13 | "batch_size": 256, 14 | "max_jitter": 128, 15 | "reverse_complement": true, 16 | "max_epochs": 10, 17 | "validation_iter": 1000, 18 | "lr": 0.001, 19 | "alpha": 10, 20 | "beta": 0.5, 21 | "min_counts": 0, 22 | "max_counts": 99999999, 23 | "sequences": "../../../oak/common/hg38/hg38.fa", 24 | "loci": "../../../chromatin-atlas/ATAC/ENCSR637XSC/preprocessing/downloads/peaks.bed.gz", 25 | "negatives": "../../../chromatin-atlas/ATAC/ENCSR637XSC/negatives_data/negatives.bed", 26 | "signals": [ 27 | "../../../chromatin-atlas/ATAC/ENCSR637XSC/preprocessing/bigWigs/ENCSR637XSC.bigWig" 28 | ], 29 | "training_chroms": ["chr2", "chr3", "chr4", "chr5", "chr6", "chr7", 30 | "chr9", "chr11", "chr12", "chr13", "chr14", "chr15", "chr16", 31 | "chr17", "chr18", "chr19", "chr20", "chr21", "chr22", "chrX"], 32 | "validation_chroms": ["chr8", "chr10"], 33 | "random_state": 0, 34 | "bias_fit_parameters": { 35 | "n_filters": null, 36 | "n_layers": 4, 37 | "alpha": null, 38 | "max_counts": null, 39 | "loci": "../../../chromatin-atlas/ATAC/ENCSR637XSC/negatives_data/negatives.bed", 40 | "verbose": null, 41 | "random_state": null 42 | }, 43 | "fit_parameters": { 44 | "batch_size": null, 45 | "sequences": null, 46 | "loci": null, 47 | "signals": null, 48 | "verbose": null, 49 | "random_state": null 50 | }, 51 | "predict_parameters": { 52 | "batch_size": null, 53 | "chroms": ["chr8", "chr10"], 54 | "profile_filename": null, 55 | "counts_filename": null, 56 | "sequences": null, 57 | "loci": null, 58 | "signals": null, 59 | "verbose": null 60 | }, 61 | "attribute_parameters": { 62 | "batch_size": 1, 63 | "chroms": ["chr8", "chr10"], 64 | "output": "profile", 65 | "loci": null, 66 | "ohe_filename": null, 67 | "attr_filename": null, 68 | "n_shuffles": 20, 69 | "verbose": null, 70 | "random_state": null 71 | }, 72 | "modisco_motifs_parameters": { 73 | "n_seqlets": 100000, 74 | "output_filename": null, 75 | "verbose": null 76 | }, 77 | "modisco_report_parameters": { 78 | "motifs": "motifs.meme", 79 | "output_folder": null, 80 | "verbose": null 81 | }, 82 | "marginalize_parameters": { 83 | "loci": "../../../chromatin-atlas/ATAC/ENCSR637XSC/negatives_data/negatives.bed", 84 | "n_loci": 100, 85 | "shuffle": false, 86 | "output_folder": null, 87 | "motifs": "ctcf.gata2.sp1.meme", 88 | "minimal": true, 89 | "verbose": null, 90 | "random_state": null 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /examples/ENCSR000AKO_minus.bigWig: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jmschrei/bpnet-lite/1fb8173b2828f4fcfe7a17f73fae13c295d0848f/examples/ENCSR000AKO_minus.bigWig -------------------------------------------------------------------------------- /examples/ENCSR000AKO_plus.bigWig: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jmschrei/bpnet-lite/1fb8173b2828f4fcfe7a17f73fae13c295d0848f/examples/ENCSR000AKO_plus.bigWig -------------------------------------------------------------------------------- /notebooks/bpnet/bpnet-fit.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "78c668b3-89c5-4c62-b7eb-c38374eeecab", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "#!pip install bpnet-lite" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "6b099a04-f82e-49e7-94a1-2ab5078c70e5", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "!bpnet fit -p bpnet_fit_example.json" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "id": "9ce180e9-e80d-4e93-8474-56255ff5902c", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import pandas as pd\n", 31 | "import seaborn as sns\n", 32 | "import matplotlib.pyplot as plt\n", 33 | "\n", 34 | "log = pd.read_table(\"example.log\")\n", 35 | "log" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "id": "7c70a7d5-3030-4c07-8a87-4fa2a35c8605", 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "sns.lineplot(data=log, x=\"Epoch\", y=\"Training MNLL\", label=\"train\")\n", 46 | "sns.lineplot(data=log, x=\"Epoch\", y=\"Validation MNLL\", label=\"val\")\n", 47 | "plt.title(\"MNLL Loss\")\n", 48 | "plt.show()\n", 49 | "\n", 50 | "sns.lineplot(data=log, x=\"Epoch\", y=\"Training Count MSE\", label=\"train\")\n", 51 | "sns.lineplot(data=log, x=\"Epoch\", y=\"Validation Count MSE\", label=\"val\")\n", 52 | "plt.title(\"Count MSE Loss\")\n", 53 | "plt.show()\n", 54 | "\n", 55 | "sns.lineplot(data=log, x=\"Epoch\", y=\"Validation Profile Pearson\", color=\"green\", label=\"profile\")\n", 56 | "sns.lineplot(data=log, x=\"Epoch\", y=\"Validation Profile Pearson\", color=\"brown\", label=\"count\")\n", 57 | "plt.title(\"Validation Pearson\")\n", 58 | "plt.show()" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "id": "cb887f1a-4b67-4c51-b32b-7ab8425f6a18", 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "!bpnet predict -p bpnet_predict_example.json" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "id": "83c9607d-ac6c-4e34-a47e-e38c77619bda", 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "import numpy as np\n", 79 | "import torch\n", 80 | "\n", 81 | "y_profile = np.load(\"y_profile.npz\")\n", 82 | "print(list(y_profile.keys()))\n", 83 | "y_profile = y_profile[\"arr_0\"]\n", 84 | "\n", 85 | "y_counts = np.load(\"y_counts.npz\")\n", 86 | "print(list(y_counts.keys()))\n", 87 | "y_counts = y_counts[\"arr_0\"]\n", 88 | "\n", 89 | "print(y_profile.shape, y_counts.shape) # batch, strands, seq_len" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "id": "4339fa48-c4db-4c2f-bea5-3be81e6319a1", 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "# get ground truth\n", 100 | "from tangermeme.io import extract_loci\n", 101 | "import json\n", 102 | "\n", 103 | "with open(\"bpnet_fit_example.json\", \"r\") as f:\n", 104 | " parameters = json.load(f)\n", 105 | "\n", 106 | "parameters['controls'] = None\n", 107 | "\n", 108 | "valid_data = extract_loci(\n", 109 | " sequences=parameters['sequences'],\n", 110 | " signals=parameters['signals'],\n", 111 | " in_signals=parameters['controls'],\n", 112 | " loci=parameters['loci'],\n", 113 | " chroms=parameters['validation_chroms'],\n", 114 | " in_window=parameters['in_window'],\n", 115 | " out_window=parameters['out_window'],\n", 116 | " max_jitter=0,\n", 117 | " ignore=list('QWERYUIOPSDFHJKLZXVBNM'),\n", 118 | " verbose=parameters['verbose']\n", 119 | ")\n", 120 | "print(valid_data[0].shape, valid_data[1].shape) # sequences and signals" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "id": "ad9f98f1-baf8-45a5-bb75-23d321a0dcfb", 127 | "metadata": {}, 128 | "outputs": [], 129 | "source": [ 130 | "idx = 100\n", 131 | "\n", 132 | "def softmax(x):\n", 133 | " \"\"\"Compute softmax values for each sets of scores in x.\"\"\"\n", 134 | " e_x = np.exp(x - np.max(x))\n", 135 | " return e_x / e_x.sum(axis=0)\n", 136 | "\n", 137 | "# predictions\n", 138 | "y = y_profile[idx, 0, :]\n", 139 | "y = softmax(y)\n", 140 | "x = np.arange(len(y))\n", 141 | "sns.lineplot(x=x, y=y, label=\"plus\", color=\"blue\")\n", 142 | "\n", 143 | "y = y_profile[idx, 1, :]\n", 144 | "y = softmax(y)\n", 145 | "x = np.arange(len(y))\n", 146 | "sns.lineplot(x=x, y=-y, label=\"minus\", color=\"orange\")\n", 147 | "plt.title(\"Predictions\")\n", 148 | "plt.show()\n", 149 | "\n", 150 | "# ground truth\n", 151 | "y = valid_data[1][idx, 0, :]\n", 152 | "x = np.arange(len(y))\n", 153 | "sns.lineplot(x=x, y=y, label=\"plus\", color=\"blue\")\n", 154 | "\n", 155 | "y = valid_data[1][idx, 1, :]\n", 156 | "x = np.arange(len(y))\n", 157 | "sns.lineplot(x=x, y=-y, label=\"minus\", color=\"orange\")\n", 158 | "plt.title(\"Ground Truth\")\n", 159 | "plt.show()" 160 | ] 161 | } 162 | ], 163 | "metadata": { 164 | "kernelspec": { 165 | "display_name": "Python 3 (ipykernel)", 166 | "language": "python", 167 | "name": "python3" 168 | }, 169 | "language_info": { 170 | "codemirror_mode": { 171 | "name": "ipython", 172 | "version": 3 173 | }, 174 | "file_extension": ".py", 175 | "mimetype": "text/x-python", 176 | "name": "python", 177 | "nbconvert_exporter": "python", 178 | "pygments_lexer": "ipython3", 179 | "version": "3.8.17" 180 | } 181 | }, 182 | "nbformat": 4, 183 | "nbformat_minor": 5 184 | } 185 | -------------------------------------------------------------------------------- /notebooks/bpnet/bpnet_fit_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_filters": 64, 3 | "n_layers": 8, 4 | "profile_output_bias": true, 5 | "count_output_bias": true, 6 | "name": "example", 7 | "batch_size": 64, 8 | "in_window": 2114, 9 | "out_window": 1000, 10 | "max_jitter": 128, 11 | "reverse_complement": true, 12 | "max_epochs": 50, 13 | "validation_iter": 100, 14 | "lr": 0.001, 15 | "alpha": 1, 16 | "verbose": true, 17 | 18 | "min_counts": 0, 19 | "max_counts": 99999999, 20 | 21 | "training_chroms": ["chr2", "chr3", "chr4", "chr5", "chr6", "chr7", 22 | "chr9", "chr11", "chr12", "chr13", "chr14", "chr15", "chr16", "chr17", 23 | "chr18", "chr19", "chr20", "chr21", "chr22", "chrX"], 24 | "validation_chroms": ["chr8", "chr10"], 25 | 26 | "sequences":"/home/manglada/databases/data/GENCODE/GRCh38.p13.genome.fa", 27 | "loci":"../../examples/ENCSR000AKO.bed", 28 | "signals":[ 29 | "../../examples/ENCSR000AKO_plus.bigWig", 30 | "../../examples/ENCSR000AKO_minus.bigWig" 31 | ], 32 | "random_state": 0 33 | } 34 | -------------------------------------------------------------------------------- /notebooks/bpnet/bpnet_predict_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size":64, 3 | "in_window":2114, 4 | "out_window":1000, 5 | "verbose":true, 6 | "chroms":["chr8", "chr10"], 7 | "model":"example.final.torch", 8 | "sequences":"/home/manglada/databases/data/GENCODE/GRCh38.p13.genome.fa", 9 | "loci":"../../examples/ENCSR000AKO.bed", 10 | "profile_filename": "y_profile.npz", 11 | "counts_filename": "y_counts.npz" 12 | } -------------------------------------------------------------------------------- /notebooks/rbpnet/rbpnet-fit.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "4470d365-8cb6-489f-9b75-78df0a3131c3", 6 | "metadata": {}, 7 | "source": [ 8 | "# Like RBPNet\n", 9 | "https://doi.org/10.1186/s13059-023-03015-7\n", 10 | "\n", 11 | "Changes:\n", 12 | "- \"n_filters\": 128\n", 13 | "- \"n_layers\": 9\n", 14 | "- \"batch_size\": 128\n", 15 | "- \"in_window\": 300\n", 16 | "- \"out_window\": 300\n", 17 | "- \"max_jitter\": 0" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "id": "78c668b3-89c5-4c62-b7eb-c38374eeecab", 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "#!pip install bpnet-lite" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "id": "6b099a04-f82e-49e7-94a1-2ab5078c70e5", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "!bpnet fit -p rbpnet_fit_example.json" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "id": "9ce180e9-e80d-4e93-8474-56255ff5902c", 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "import pandas as pd\n", 48 | "import seaborn as sns\n", 49 | "import matplotlib.pyplot as plt\n", 50 | "\n", 51 | "log = pd.read_table(\"example.log\")\n", 52 | "log" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "id": "7c70a7d5-3030-4c07-8a87-4fa2a35c8605", 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "sns.lineplot(data=log, x=\"Epoch\", y=\"Training MNLL\", label=\"train\")\n", 63 | "sns.lineplot(data=log, x=\"Epoch\", y=\"Validation MNLL\", label=\"val\")\n", 64 | "plt.title(\"MNLL Loss\")\n", 65 | "plt.show()\n", 66 | "\n", 67 | "sns.lineplot(data=log, x=\"Epoch\", y=\"Training Count MSE\", label=\"train\")\n", 68 | "sns.lineplot(data=log, x=\"Epoch\", y=\"Validation Count MSE\", label=\"val\")\n", 69 | "plt.title(\"Count MSE Loss\")\n", 70 | "plt.show()\n", 71 | "\n", 72 | "sns.lineplot(data=log, x=\"Epoch\", y=\"Validation Profile Pearson\", color=\"green\", label=\"profile\")\n", 73 | "sns.lineplot(data=log, x=\"Epoch\", y=\"Validation Profile Pearson\", color=\"brown\", label=\"count\")\n", 74 | "plt.title(\"Validation Pearson\")\n", 75 | "plt.show()" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "id": "cb887f1a-4b67-4c51-b32b-7ab8425f6a18", 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "!bpnet predict -p rbpnet_predict_example.json" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "id": "83c9607d-ac6c-4e34-a47e-e38c77619bda", 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "import numpy as np\n", 96 | "import torch\n", 97 | "\n", 98 | "y_profile = np.load(\"y_profile.npz\")\n", 99 | "print(list(y_profile.keys()))\n", 100 | "y_profile = y_profile[\"arr_0\"]\n", 101 | "\n", 102 | "y_counts = np.load(\"y_counts.npz\")\n", 103 | "print(list(y_counts.keys()))\n", 104 | "y_counts = y_counts[\"arr_0\"]\n", 105 | "\n", 106 | "print(y_profile.shape, y_counts.shape) # batch, strands, seq_len" 107 | ] 108 | }, 109 | { 110 | "cell_type": "code", 111 | "execution_count": null, 112 | "id": "4339fa48-c4db-4c2f-bea5-3be81e6319a1", 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "# get ground truth\n", 117 | "from tangermeme.io import extract_loci\n", 118 | "import json\n", 119 | "\n", 120 | "with open(\"rbpnet_fit_example.json\", \"r\") as f:\n", 121 | " parameters_fit = json.load(f)\n", 122 | " \n", 123 | "with open(\"rbpnet_predict_example.json\", \"r\") as f:\n", 124 | " parameters_predict = json.load(f)\n", 125 | "\n", 126 | "parameters_fit['controls'] = None\n", 127 | "\n", 128 | "valid_data = extract_loci(\n", 129 | " sequences=parameters_fit['sequences'],\n", 130 | " signals=parameters_fit['signals'],\n", 131 | " in_signals=parameters_fit['controls'],\n", 132 | " loci=parameters_fit['loci'],\n", 133 | " chroms=parameters_fit['validation_chroms'],\n", 134 | " in_window=parameters_predict['in_window'],\n", 135 | " out_window=parameters_predict['out_window'],\n", 136 | " max_jitter=0,\n", 137 | " ignore=list('QWERYUIOPSDFHJKLZXVBNM'),\n", 138 | " verbose=parameters_fit['verbose']\n", 139 | ")\n", 140 | "print(valid_data[0].shape, valid_data[1].shape) # sequences and signals" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "id": "ad9f98f1-baf8-45a5-bb75-23d321a0dcfb", 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "idx = 150\n", 151 | "\n", 152 | "def softmax(x):\n", 153 | " \"\"\"Compute softmax values for each sets of scores in x.\"\"\"\n", 154 | " e_x = np.exp(x - np.max(x))\n", 155 | " return e_x / e_x.sum(axis=0)\n", 156 | "\n", 157 | "# predictions\n", 158 | "y = y_profile[idx, 0, :]\n", 159 | "y = softmax(y)\n", 160 | "x = np.arange(len(y))\n", 161 | "sns.lineplot(x=x, y=y, label=\"plus\", color=\"blue\")\n", 162 | "\n", 163 | "y = y_profile[idx, 1, :]\n", 164 | "y = softmax(y)\n", 165 | "x = np.arange(len(y))\n", 166 | "sns.lineplot(x=x, y=-y, label=\"minus\", color=\"orange\")\n", 167 | "plt.title(\"Predictions\")\n", 168 | "plt.show()\n", 169 | "\n", 170 | "# ground truth\n", 171 | "y = valid_data[1][idx, 0, :]\n", 172 | "x = np.arange(len(y))\n", 173 | "sns.lineplot(x=x, y=y, label=\"plus\", color=\"blue\")\n", 174 | "\n", 175 | "y = valid_data[1][idx, 1, :]\n", 176 | "x = np.arange(len(y))\n", 177 | "sns.lineplot(x=x, y=-y, label=\"minus\", color=\"orange\")\n", 178 | "plt.title(\"Ground Truth\")\n", 179 | "plt.show()" 180 | ] 181 | }, 182 | { 183 | "cell_type": "code", 184 | "execution_count": null, 185 | "id": "d25da0a9-2411-49ad-87a3-2b2b8f164a4e", 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "x = y_counts[:,0]\n", 190 | "y = valid_data[1].sum(axis=1).sum(axis=1)\n", 191 | "sns.scatterplot(x=x, y=y, alpha=0.1)\n", 192 | "plt.ylabel(\"Ground Truth\")\n", 193 | "plt.xlabel(\"Predicted\")\n", 194 | "plt.title(\"Total Counts\")\n", 195 | "plt.show()" 196 | ] 197 | } 198 | ], 199 | "metadata": { 200 | "kernelspec": { 201 | "display_name": "Python 3 (ipykernel)", 202 | "language": "python", 203 | "name": "python3" 204 | }, 205 | "language_info": { 206 | "codemirror_mode": { 207 | "name": "ipython", 208 | "version": 3 209 | }, 210 | "file_extension": ".py", 211 | "mimetype": "text/x-python", 212 | "name": "python", 213 | "nbconvert_exporter": "python", 214 | "pygments_lexer": "ipython3", 215 | "version": "3.8.17" 216 | } 217 | }, 218 | "nbformat": 4, 219 | "nbformat_minor": 5 220 | } 221 | -------------------------------------------------------------------------------- /notebooks/rbpnet/rbpnet_fit_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "n_filters": 128, 3 | "n_layers": 9, 4 | "profile_output_bias": true, 5 | "count_output_bias": true, 6 | "name": "example", 7 | "batch_size": 128, 8 | "in_window": 400, 9 | "out_window": 300, 10 | "max_jitter": 0, 11 | "reverse_complement": true, 12 | "max_epochs": 50, 13 | "validation_iter": 100, 14 | "lr": 0.001, 15 | "alpha": 1, 16 | "verbose": true, 17 | 18 | "min_counts": 0, 19 | "max_counts": 99999999, 20 | 21 | "training_chroms": ["chr2", "chr3", "chr4", "chr5", "chr6", "chr7", 22 | "chr9", "chr11", "chr12", "chr13", "chr14", "chr15", "chr16", "chr17", 23 | "chr18", "chr19", "chr20", "chr21", "chr22", "chrX"], 24 | "validation_chroms": ["chr8", "chr10"], 25 | 26 | "sequences":"/home/manglada/databases/data/GENCODE/GRCh38.p13.genome.fa", 27 | "loci":"../../examples/ENCSR000AKO.bed", 28 | "signals":[ 29 | "../../examples/ENCSR000AKO_plus.bigWig", 30 | "../../examples/ENCSR000AKO_minus.bigWig" 31 | ], 32 | "random_state": 0 33 | } 34 | -------------------------------------------------------------------------------- /notebooks/rbpnet/rbpnet_predict_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "batch_size": 64, 3 | "in_window": 1300, 4 | "out_window": 1200, 5 | "verbose": true, 6 | "chroms": ["chr8", "chr10"], 7 | "model": "example.final.torch", 8 | "sequences": "/home/manglada/databases/data/GENCODE/GRCh38.p13.genome.fa", 9 | "loci": "../../examples/ENCSR000AKO.bed", 10 | "profile_filename": "y_profile.npz", 11 | "counts_filename": "y_counts.npz" 12 | } -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='bpnet-lite', 5 | version='0.8.1', 6 | author='Jacob Schreiber', 7 | author_email='jmschreiber91@gmail.com', 8 | packages=['bpnetlite'], 9 | scripts=['bpnet', 'chrombpnet'], 10 | url='https://github.com/jmschrei/bpnet-lite', 11 | license='LICENSE.txt', 12 | description='bpnet-lite is a minimal implementation of BPNet, a neural network aimed at interpreting regulatory activity of the genome.', 13 | install_requires=[ 14 | "numpy >= 1.14.2", 15 | "scipy >= 1.0.0", 16 | "pandas >= 1.3.3", 17 | "torch >= 1.9.0", 18 | "h5py >= 3.7.0", 19 | "tqdm >= 4.64.1", 20 | "seaborn >= 0.11.2", 21 | "modisco-lite >= 2.0.0", 22 | "tangermeme >= 0.2.3", 23 | "bam2bw" 24 | ], 25 | ) 26 | --------------------------------------------------------------------------------