├── LICENSE ├── README.md ├── deepmod2 ├── docs ├── Example.md ├── Install.md ├── Training.md └── Usage.md ├── environment.yml ├── plot_utils ├── README.md ├── Signal_Plot_Examples.ipynb └── plot.py ├── src ├── detect.py ├── models.py ├── models │ ├── bilstm.cfg │ ├── bilstm │ │ ├── R10.4.1_4kHz_v3.5 │ │ ├── R10.4.1_4kHz_v4.1 │ │ ├── R10.4.1_5kHz_v4.3 │ │ ├── R10.4.1_5kHz_v5.0 │ │ └── R9.4.1 │ ├── transformer.cfg │ └── transformer │ │ ├── R10.4.1_4kHz_v3.5 │ │ ├── R10.4.1_4kHz_v4.1 │ │ ├── R10.4.1_5kHz_v4.3 │ │ ├── R10.4.1_5kHz_v5.0 │ │ └── R9.4.1 └── utils.py └── train ├── generate_features.py ├── train_models.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Wang Genomics Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepMod2 2 | DeepMod2 is a computational tool for detecting DNA 5mC methylation from Oxford Nanopore reads. It uses a BiLSTM model to predict per-read and per-site 5mC methylations for CpG sites and produces a methylation tagged BAM file. DeepMod2 can call methylation from POD5 and FAST5 files basecalled with Guppy or Dorado and provides models for R10.4.1 and R9.4.1 flowcells. 3 | 4 |

5 | 6 | DeepMod2 is distributed under the [MIT License by Wang Genomics Lab](https://wglab.mit-license.org/). 7 | 8 | ### Citing DeepMod2 9 | Ahsan, M.U., Gouru, A., Chan, J. et al. A signal processing and deep learning framework for methylation detection using Oxford Nanopore sequencing. Nat Commun 15, 1448 (2024). https://doi.org/10.1038/s41467-024-45778-y 10 | ## Installation 11 | Please refer to [Installation](https://github.com/WGLab/DeepMod2/blob/main/docs/Install.md) for how to install DeepMod2. 12 | 13 | ## Inference 14 | Quick usage guide for model inference: 15 | 1. Basecall your FAST5/POD5 files with Dorado (using `--emit-moves`) or Guppy (using `--bam_out --moves_out`) parameters to get a BAM file with move tables: 16 | ``` 17 | dorado basecaller MODEL INPUT_DIR --emit-moves > basecall.bam 18 | ``` 19 | Make sure to use the appropriate Guppy/Dorado model for your sequencing kit. You can supply a reference genome to Guppy/Dorado to get aligned BAM files, or use minimap2 to align these BAM files later. 20 | 2. (Optional but recommended) Align basecalled reads to a reference genome while retaining the move tables: 21 | ``` 22 | samtools fastq basecall.bam -T mv,ts | minimap2 -ax map-ont ref.fa - -y -t NUM_THREADS |samtools view -o aligned.bam 23 | ``` 24 | 3. Run DeepMod2 by providing BAM file and the folder containing FAST5 or POD5 signal files as inputs. You can provide reference FASTA file to get reference anchored methylation calls and per-site frequencies if the BAM file is aligned. Specify the model you want to use and the file type of raw signal files. Use multiple cores and/or GPUs for speedup. 25 | 26 | a) If using an aligned BAM file input: 27 | ``` 28 | python PATH_TO_DEEPMOD2_REPOSITORY/deepmod2 detect --bam reads.bam --input INPUT_DIR --model MODEL --file_type FILE_TYPE --threads NUM_THREADS --ref ref.fa --output MOD_CALLS 29 | ``` 30 | 31 | b) If using an unaligned BAM file input: 32 | ``` 33 | python PATH_TO_DEEPMOD2_REPOSITORY/deepmod2 detect --bam reads.bam --input INPUT_DIR --model MODEL --file_type FILE_TYPE --threads NUM_THREADS --output MOD_CALLS 34 | ``` 35 | This will give you a per-read prediction text file `MOD_CALLS/output.per_read`, a per-site prediction file `MOD_CALLS/output.per_site`, a per-site prediction file with both strands aggregated `MOD_CALLS/output.per_site.aggregated`, and a methylation annotated BAM file `MOD_CALLS/output.bam`. 36 | 5. Visualize the annotated BAM file produced by DeepMod2 in IGV file. In IGV, select 'Color alignments by' and 'base modifications (5mC)'. The following steps will allow you to open the tagged BAM file in IGV: 37 | 38 | a) If an aligned BAM is given to DeepMod2, you only need to sort and index the DeepMod2 methylation tagged BAM file: 39 | ``` 40 | samtools sort MOD_CALLS/output.bam -o MOD_CALLS/final.bam --write-index 41 | ``` 42 | 43 | b) If an unaligned BAM is given to DeepMod2, first align the DeepMod2 methylation tagged BAM file (while preserving methylation tags MM and ML), then sort and index it: 44 | ``` 45 | samtools fastq MOD_CALLS/output.bam -T MM,ML,mv,ts| minimap2 -ax map-ont ref.fa - -y -t NUM_THREADS |samtools sort -o MOD_CALLS/final.bam --write-index 46 | ``` 47 | 48 |
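Before opening the final BAM file in IGV, you can optionally confirm that the methylation tags written by DeepMod2 survived sorting (and realignment in case b). A minimal sanity check, assuming `samtools` is on your PATH, counts how many of the first 100 records carry an MM tag:
```
samtools view MOD_CALLS/final.bam | head -n 100 | grep -c 'MM:Z:'
```
A count close to 100 suggests the 5mC tags are present and IGV will be able to color alignments by base modification.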

49 | 50 | 51 | Please refer to [Usage.md](docs/Usage.md) for details on how to use DeepMod2. 52 | 53 | ## Training 54 | For a detailed usage guide for model training, refer to [Training.md](docs/Training.md) and the code under [train](train/). 55 | 56 | ## Signal Plotting 57 | For a detailed usage guide for signal plotting and comparison, refer to the Jupyter notebook [Signal_Plot_Examples.ipynb](plot_utils/Signal_Plot_Examples.ipynb) and the code under [plot_utils](plot_utils/). The code works with signal files in POD5 format and aligned BAM files with move tables. The Jupyter notebook uses the following data: [plot_files.tar.gz](https://github.com/WGLab/DeepMod2/files/14985308/plot_files.tar.gz). 58 | 59 | ## Models 60 | The following models for 5mC detection in CpG motif are provided in the repository. Use `--model MODEL_NAME` to specify a model to use. You only need to provide the name of the model, not the path to it. Each model is compatible with a different Dorado or Guppy basecalling model version. 61 | |Model Architecture|DeepMod2 Model Name|Flowcell
(Sampling Rate)| Compatible Dorado/Guppy Basecalling Model| 62 | |-|-|-|-| 63 | |BiLSTM|**bilstm_r10.4.1_5khz_v5.0**|R10.4.1 (5kHz)|**dna_r10.4.1_e8.2_400bps_(fast\|hac\|sup)@v5.0.0** 64 | |Transformer|**transformer_r10.4.1_5khz_v5.0**|R10.4.1 (5kHz)|**dna_r10.4.1_e8.2_400bps_(fast\|hac\|sup)@v5.0.0** 65 | |BiLSTM|**bilstm_r10.4.1_5khz_v4.3**|R10.4.1 (5kHz)|**dna_r10.4.1_e8.2_400bps_(fast\|hac\|sup)@v4.3.0** 66 | |Transformer|**transformer_r10.4.1_5khz_v4.3**|R10.4.1 (5kHz)|**dna_r10.4.1_e8.2_400bps_(fast\|hac\|sup)@v4.3.0** 67 | |BiLSTM|**bilstm_r10.4.1_4khz_v4.1**|R10.4.1 (4kHz)|**dna_r10.4.1_e8.2_400bps_(fast\|hac\|sup)@v4.1.0** in Dorado
**dna_r10.4.1_e8.2_400bps_(fast\|hac\|sup).cfg** in Guppy 6.5.7| 68 | |Transformer|**transformer_r10.4.1_4khz_v4.1**|R10.4.1 (4kHz)|**dna_r10.4.1_e8.2_400bps_(fast\|hac\|sup)@v4.1.0** in Dorado
**dna_r10.4.1_e8.2_400bps_(fast\|hac\|sup).cfg** in Guppy 6.5.7| 69 | |BiLSTM|**bilstm_r10.4.1_4khz_v3.5**
(Published in DeepMod2 paper)|R10.4.1 (4kHz)|**dna_r10.4.1_e8.2_400bps_(fast\|hac\|sup)@v3.5.2** in Dorado
**dna_r10.4.1_e8.2_400bps_(fast\|hac\|sup).cfg** in Guppy 6.3.8| 70 | |Transformer|**transformer_r10.4.1_4khz_v3.5**
(Published in DeepMod2 paper)|R10.4.1 (4kHz)|**dna_r10.4.1_e8.2_400bps_(fast\|hac\|sup)@v3.5.2** in Dorado
**dna_r10.4.1_e8.2_400bps_(fast\|hac\|sup).cfg** in Guppy 6.3.8| 71 | |BiLSTM|**bilstm_r9.4.1**
(Published in DeepMod2 paper)|R9.4.1 (4kHz)|**dna_r9.4.1_e8_(hac\|sup)@v3.3** in Dorado
**dna_r9.4.1_450bps_(hac\|sup).cfg** in Guppy 6.3.8 and 6.5.7| 72 | |Transformer|**transformer_r9.4.1**
(Published in DeepMod2 paper)|R9.4.1 (4kHz)|**dna_r9.4.1_e8_(hac\|sup)@v3.3** in Dorado
**dna_r9.4.1_450bps_(hac\|sup).cfg** in Guppy 6.3.8 and 6.5.7| 73 | 74 | ## Examples 75 | 76 | Please refer to [Example](https://github.com/WGLab/DeepMod2/blob/main/docs/Example.md) for a complete tutorial on how to run DeepMod2 under various scenarios. A test dataset consisting of a small example of ~60 reads in both POD5 and FAST5 format as well as the expected results in this release:https://github.com/WGLab/DeepMod2/releases/tag/v0.3.0. 77 | -------------------------------------------------------------------------------- /deepmod2: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import time, itertools, torch 4 | 5 | import datetime, os, shutil, argparse, sys, pysam 6 | 7 | from src import utils 8 | 9 | if __name__ == '__main__': 10 | 11 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 12 | parser.add_argument("--print_models", help='Print details of models available', default=False, action='store_true') 13 | main_subparsers = parser.add_subparsers(title="Options", dest="option") 14 | 15 | parent_parser = argparse.ArgumentParser(add_help=False,) 16 | parent_parser.add_argument("--prefix", help='Prefix for the output files',type=str, default='output') 17 | parent_parser.add_argument("--output", help= 'Path to folder where intermediate and final files will be stored, default is current working directory', type=str) 18 | 19 | 20 | 21 | parent_parser.add_argument("--qscore_cutoff", help='Minimum cutoff for mean quality score of a read',type=float, default=0) 22 | parent_parser.add_argument("--length_cutoff", help='Minimum cutoff for read length',type=int, default=0) 23 | 24 | parent_parser.add_argument("--mod_t", help= 'Probability threshold for a per-read prediction to be considered modified. Only predictiond with probability >= mod_t will be considered as modified for calculation of per-site modification levels.', default=0.5, type=float) 25 | 26 | parent_parser.add_argument("--unmod_t", help= 'Probability threshold for a per-read prediction to be considered unmodified. Only predictiond with probability < unmod_t will be considered as unmodified for calculation of per-site modification levels.', default=0.5, type=float) 27 | 28 | parent_parser.add_argument("--include_non_cpg_ref", help='Include non-CpG reference loci in per-site output where reads have CpG motif.',default=False, action='store_true') 29 | 30 | detect_parser = main_subparsers.add_parser("detect", parents=[parent_parser], 31 | add_help=True, 32 | help="Call methylation from Guppy or Dorado basecalled POD5/FAST5 files using move tables for signal alignment.", formatter_class=argparse.ArgumentDefaultsHelpFormatter) 33 | 34 | detect_required=detect_parser.add_argument_group("Required Arguments") 35 | 36 | detect_parser.add_argument("--motif", help='Motif for detecting modifications followed by zero-based indices of nucleotides within the motif to call modification on. Default is CpG methylation "CG 0". Modification will be called for motif matches on the read and reference unless --reference_motif_only is used to restrict to reference motif matches only. Multiple indices can be specified but they should refer to the same nucleotide letter. The motif and each index listed should be separated by whitespace, e.g. "--motif CGCG 0 2"', nargs='*', default=['CG','0']) 37 | 38 | detect_parser.add_argument("--mod_symbol", help='Symbol to use for modified base in BAM tag MM. 
Default is to use "m" for 5mC in CG motif, and for other motifs the default is the canonical nucleotide symbol.' , type=str) 39 | 40 | detect_parser.add_argument("--mod_positions", help='A tab separated list of reference coordinates to call modification on. Modifications will only be called on reference positions specified that overlap with the motif, and no modification detection will be performed on other loci. The file shoule have the following format: "contig position strand" on each line. Position should be zero-based and strand should be "+" or "-".', default=None) 41 | 42 | detect_parser.add_argument("--reference_motif_only", help='Restrict modification calling to reference motif matches only', default=False, action='store_true') 43 | 44 | detect_parser.add_argument("--seq_type", help='Specify DNA or direct RNA sequencing.',choices=['dna','rna'], type=str,required=True) 45 | 46 | detect_parser.add_argument("--threads", help='Number of threads to use for processing signal and running model inference. If a GPU is used for inference, then --threads number of threads will be running on GPU concurrently. The total number of threads used by DeepMod2 is equal to --threads plus --bam_threads. It is recommended to run DeepMod2 with mutliple cores, and use at least 4 bam_threads for compressing BAM file.',type=int, default=4) 47 | 48 | detect_parser.add_argument("--ref", help='Path to reference FASTA file to anchor methylation calls to reference loci. If no reference is provided, only the motif loci on reads will be used.', type=str) 49 | 50 | 51 | detect_required.add_argument("--model", help='Name of the model. Recommended model for R9.4.1 flowcells is "bilstm_r9.4.1", for R10.4.1 flowcell (5kHz sampling) it is "bilstm_r10.4.1_5khz_v4.3", and for R10.4.1 flowcell (4kHz sampling) it is "bilstm_r10.4.1_4khz_v4.1". Use --print_models to display all models available along with compatible basecaller models. For custom models, provide the model config file and model checkpoint path separated by comma, e.g. "model.cfg,modelfile"',type=str, required=True) 52 | 53 | detect_required.add_argument("--bam", help='Path to aligned or unaligned BAM file. It is ideal to have move table in BAM file but move table from FAST5 fies can also be used. Aligned BAM file is required for reference anchored methylation calls, otherwise only the motif loci on reads will be called.', type=str, required=True) 54 | 55 | detect_required.add_argument("--file_type", help='Specify whether the signal is in FAST5 or POD5 file format. If POD5 file is used, then move table must be in BAM file.',choices=['fast5','pod5'], type=str,required=True) 56 | 57 | detect_required.add_argument("--input", help='Path to POD5/FAST5 file or folder containing POD5/FAST5 files. If folder provided, then POD5/FAST5 files will be recusrviely searched', type=str, required=True) 58 | 59 | detect_parser.add_argument("--guppy_group", help='Name of the guppy basecall group if move table is in FAST5 file.',type=str, default='Basecall_1D_000') 60 | detect_parser.add_argument("--chrom", nargs='*', help='A space/whitespace separated list of contigs, e.g. chr3 chr6 chr22. If not list is provided then all chromosomes in the reference are used.') 61 | 62 | detect_parser.add_argument("--fast5_move", help='Use move table from FAST5 file instead of BAM file. 
If this flag is set, specify a basecall group for FAST5 file using --guppy_group parameter and ensure that the FAST5 files contains move table.', default=False, action='store_true') 63 | 64 | detect_parser.add_argument("--skip_per_site", help='Skip per site output', default=False, action='store_true') 65 | detect_parser.add_argument("--device", help='Device to use for running pytorch models. you can set --device=cpu for cpu, or --device=cuda for GPU. You can also specify a particular GPU device such as --device=cuda:0 or --device=cuda:1 . If --device paramater is not set by user, then GPU will be used if available otherwise CPU will be used.', type=str) 66 | detect_parser.add_argument("--disable_pruning", help='Disable model pruning (not recommended for CPU inference). By default models are pruned to remove some weights with low L1 norm in linear layers. Pruning has little effect on model accuracy, it can signifcantly improve CPU inference time but not GPU inference time.', default=False, action='store_true') 67 | 68 | detect_parser.add_argument("--exclude_ref_features", help='Exclude reference sequence from feature matrix. By default, if a reference FASTA file is provided via --ref parameter, then the reference sequence is added as a feature for aligned reads, but not if a read is unmapped or if no reference is provided.', default=False, action='store_true') 69 | detect_parser.add_argument("--batch_size", help='Batch size to use for GPU inference. For CPU inference, batch size is fixed at 512.',type=int, default=1024) 70 | 71 | detect_parser.add_argument("--bam_threads", help='Number of threads to use for compressed BAM output. Setting it lower than 4 can significantly lower the runtime.',type=int, default=4) 72 | detect_parser.add_argument("--skip_unmapped", help='Skip unmapped reads from methylation calling. If --chrom is used then unmapped are automatically skipped.', default=False, action='store_true') 73 | 74 | merge_parser = main_subparsers.add_parser("merge", parents=[parent_parser], 75 | add_help=True, 76 | help="Merge per-read calls into per-site calls", formatter_class=argparse.ArgumentDefaultsHelpFormatter) 77 | 78 | merge_parser.add_argument("--input", nargs='*', help= 'List of paths of per-read methylation calls to merge. File paths should be separated by space/whitespace. Use either --input or --list argument, but not both.') 79 | merge_parser.add_argument("--list", help= 'A file containing paths to per-read methylation calls to merge (one per line). Use either --inputs or --list argument, but not both.', type=str) 80 | merge_parser.add_argument("--cpg_output", help= 'Create an additional per-site output file with forward and negative strand counts for CpG sites combined.', default=False, action='store_true') 81 | 82 | if len(sys.argv)==1: 83 | parser.print_help() 84 | parser.exit() 85 | 86 | 87 | elif len(sys.argv)==2: 88 | if sys.argv[1]=='merge': 89 | merge_parser.print_help() 90 | merge_parser.exit() 91 | 92 | elif sys.argv[1]=='detect': 93 | detect_parser.print_help() 94 | detect_parser.exit() 95 | 96 | args = parser.parse_args() 97 | 98 | 99 | if args.print_models: 100 | utils.get_model_help() 101 | parser.exit() 102 | 103 | t=time.time() 104 | 105 | print('%s: Starting DeepMod2.' 
%str(datetime.datetime.now()), flush=True) 106 | 107 | if not args.output: 108 | args.output=os.getcwd() 109 | 110 | os.makedirs(args.output, exist_ok=True) 111 | 112 | if args.option=='merge': 113 | if args.input: 114 | input_list= args.input 115 | 116 | elif args.list: 117 | with open(args.list,'r') as file_list: 118 | input_list=[x.rstrip('\n') for x in file_list.readlines()] 119 | 120 | params={'output':args.output, 'prefix':args.prefix, 'qscore_cutoff':args.qscore_cutoff, 121 | 'length_cutoff':args.length_cutoff, 'mod_t':args.mod_t, 122 | 'unmod_t':args.unmod_t,'include_non_cpg_ref':args.include_non_cpg_ref, 'cpg_output':args.cpg_output} 123 | 124 | site_pred_file=utils.get_per_site(params, input_list) 125 | 126 | else: 127 | if args.chrom: 128 | args.skip_unmapped=True 129 | chrom_list=args.chrom 130 | else: 131 | chrom_list=pysam.AlignmentFile(args.bam,'rb',check_sq=False).references 132 | 133 | if args.device: 134 | dev=args.device 135 | else: 136 | if torch.cuda.is_available(): 137 | dev = "cuda" 138 | else: 139 | dev = "cpu" 140 | 141 | 142 | motif_seq, exp_motif_seq, motif_ind, valid_motif=utils.motif_check(args.motif) 143 | if not valid_motif: 144 | sys.exit(3) 145 | 146 | params={'input':args.input, 'output':args.output, 'threads':args.threads, 147 | 'prefix':args.prefix, 'model':args.model, 148 | 'qscore_cutoff':args.qscore_cutoff, 'ref':args.ref, 149 | 'length_cutoff':args.length_cutoff, 'bam':args.bam, 150 | 'file_type':args.file_type, 'fast5_move':args.fast5_move, 151 | 'guppy_group':args.guppy_group, 152 | 'mod_t':args.mod_t, 'unmod_t':args.unmod_t, 'include_non_cpg_ref': args.include_non_cpg_ref, 153 | 'skip_per_site':args.skip_per_site, 'chrom_list':chrom_list, "dev":dev, 154 | 'disable_pruning':args.disable_pruning, 'batch_size':args.batch_size, 155 | 'exclude_ref_features':args.exclude_ref_features,'bam_threads':args.bam_threads, 156 | 'skip_unmapped':args.skip_unmapped, 'mod_positions':args.mod_positions, 157 | 'motif_seq':motif_seq, 'motif_ind':motif_ind,'exp_motif_seq':exp_motif_seq, 158 | 'reference_motif_only':args.reference_motif_only, 159 | 'seq_type':args.seq_type, 160 | 'mod_symbol':args.mod_symbol 161 | } 162 | 163 | print('\n%s: \nCommand: python %s\n' %(str(datetime.datetime.now()), ' '.join(sys.argv)), flush=True) 164 | 165 | with open(os.path.join(args.output,'args'),'w') as file: 166 | file.write('Command: python %s\n\n\n' %(' '.join(sys.argv))) 167 | file.write('------Parameters Used For Running DeepMod2------\n') 168 | for k in vars(args): 169 | file.write('{}: {}\n'.format(k,vars(args)[k]) ) 170 | 171 | from src import detect 172 | detect.call_manager(params) 173 | 174 | print('\n%s: Time elapsed=%.4fs' %(str(datetime.datetime.now()),time.time()-t), flush=True) 175 | -------------------------------------------------------------------------------- /docs/Example.md: -------------------------------------------------------------------------------- 1 | # DeepMod2 Run Example 2 | 3 | This example shows how to use DeepMod2 to prediction 5mC methylation from FAST5 files. We will use sample dataset from https://github.com/WGLab/DeepMod2/releases/tag/v0.3.0. 4 | 5 | - [1. 
Methylation Calling from POD5 files with Dorado basecalling](Example.md#1-methylation-calling-from-pod5-files-with-dorado-basecalling) 6 | - [1.1 Reference Anchored Methylation Calling](Example.md#11-reference-anchored-methylation-calling) 7 | - [1.1.1 Dorado Basecalling and Read Alignment to Reference Genome](Example.md#111-dorado-basecalling-and-read-alignment-to-reference-genome) 8 | - [1.1.2 (Optional) Read Phasing for diploid genomes](Example.md#112-optional-read-phasing-for-diploid-genomes) 9 | - [1.1.3 Methylation Calling with DeepMod2](Example.md#113-methylation-calling-with-deepmod2) 10 | - [1.1.4 Visualizing DeepMod2 Methylation in IGV](Example.md#114-visualizing-deepmod2-methylation-in-igv) 11 | - [1.2 Reference free methylation calling](Example.md#12-reference-free-methylation-calling) 12 | - [1.2.1 Dorado Basecalling ](Example.md#121-dorado-basecalling) 13 | - [1.2.2 Reference Free Methylation Calling with DeepMod2](Example.md#122-reference-free-methylation-calling-with-deepmod2) 14 | - [1.2.3 Optional Read Alignment to Reference Genome and Per-site frequency calculation with modkit](Example.md#123-optional-read-alignment-to-reference-genome-and-per-site-frequency-calculation-with-modkit) 15 | - [2. Methylation Calling from FAST5 files with Guppy basecalling](Example.md#2-methylation-calling-from-fast5-files-with-guppy-basecalling) 16 | - [2.1 Reference Anchored Methylation Calling](Example.md#21-reference-anchored-methylation-calling) 17 | - [2.1.1 Guppy Basecalling and Read Alignment to Reference Genome](Example.md#211-guppy-basecalling-and-read-alignment-to-reference-genome) 18 | - [2.1.2 Methylation Calling with DeepMod2 and Guppy Basecalling](Example.md#212-methylation-calling-with-deepmod2-and-guppy-basecalling) 19 | - [2.2 Reference free methylation calling with Guppy Basecalling](Example.md#22-reference-free-methylation-calling-with-guppy-basecalling) 20 | 21 | # 1. 
Methylation Calling from POD5 files with Dorado basecalling 22 | **Prepare Directories** 23 | ``` 24 | INPUT_DIR=data 25 | OUTPUT_DIR=mod 26 | 27 | mkdir -p ${INPUT_DIR} 28 | mkdir -p ${OUTPUT_DIR} 29 | ``` 30 | **Download Software Packges** 31 | ``` 32 | # Install DeepMod2 33 | git clone https://github.com/WGLab/DeepMod2.git ${INPUT_DIR}/DeepMod2 34 | conda env create -f ${INPUT_DIR}/DeepMod2/environment.yml 35 | conda activate deepmod2 36 | conda install samtools minimap2 bedtools -y 37 | 38 | # Download Dorado Basecaller and model 39 | wget -qO- https://cdn.oxfordnanoportal.com/software/analysis/dorado-0.5.3-linux-x64.tar.gz | tar xzf - -C ${INPUT_DIR} 40 | ${INPUT_DIR}/dorado-0.5.3-linux-x64/bin/dorado download --model dna_r10.4.1_e8.2_400bps_hac@v4.3.0 --directory ${INPUT_DIR}/dorado-0.5.3-linux-x64/models/ 41 | ``` 42 | 43 | **Download Nanopore data and reference genome** 44 | ``` 45 | # Download reference genome 46 | wget ftp://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/000/001/405/GCF_000001405.39_GRCh38.p13/GRCh38_major_release_seqs_for_alignment_pipelines/GCA_000001405.15_GRCh38_no_alt_plus_hs38d1_analysis_set.fna.gz -O -| gunzip -c > ${INPUT_DIR}/GRCh38.fa 47 | wget ftp://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/000/001/405/GCF_000001405.39_GRCh38.p13/GRCh38_major_release_seqs_for_alignment_pipelines/GCA_000001405.15_GRCh38_no_alt_plus_hs38d1_analysis_set.fna.fai -O ${INPUT_DIR}/GRCh38.fa.fai 48 | 49 | # Download POD5 files 50 | mkdir -p ${INPUT_DIR}/nanopore_raw_data 51 | 52 | wget -qO- https://github.com/WGLab/DeepMod2/files/14368872/sample.pod5.tar.gz| tar xzf - -C ${INPUT_DIR}/nanopore_raw_data 53 | ``` 54 | 55 | ## 1.1 Reference Anchored Methylation Calling 56 | In order to perform reference anchored methylation calling, we will provide an aligned BAM to DeepMod2. In this case, DeepMod2 will detect 5mC in all CpG motifs found on the read, as wel as any bases of the read that map to a reference CpG site. Finally, it will combine per-read predictions for a given reference CpG site into a per-site methylation frequency. Any unaligned reads or unaligned segments of the reads will also be analyzed for 5mC and reported in per-read output and BAM file, but would not be used in per-site frequency calculation. 57 | 58 | ### 1.1.1 Dorado Basecalling and Read Alignment to Reference Genome 59 | First we will perform basecalling of our nanopore signal file using Dorado basecaller. It is possible to align the reads during basecalling or align the reads after basecalling. Both options are shown below. Since we need move table for our basecalled DNA sequences, we will use `--emit-moves` while running Dorado, which will produce an aligned (Option A) or unaligned BAM file (Option B). 60 | 61 | #### Option A: Perform Read Alignment during Bascalling with Dorado 62 | 63 | Dorado has the option to perform read alignment using minimap2 during basecalling if a reference FASTA file is provided as `--reference` option. This can be be helpful in reducing the number of steps needed to run. 64 | 65 | ``` 66 | ${INPUT_DIR}/dorado-0.5.3-linux-x64/bin/dorado basecaller --emit-moves --recursive --reference ${INPUT_DIR}/GRCh38.fa ${INPUT_DIR}/dorado-0.5.3-linux-x64/models/dna_r10.4.1_e8.2_400bps_hac@v4.3.0 ${INPUT_DIR}/nanopore_raw_data > ${OUTPUT_DIR}/aligned.bam 67 | ``` 68 | 69 | This will produce an aligned BAM file named `aligned.bam` under the `$OUTPUT_DIR` folder. 
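DeepMod2 relies on the move table for signal-to-base alignment, so it can be worth verifying that the `mv` and `ts` tags were carried through basecalling and alignment. A quick optional check, assuming `samtools` is available in the active conda environment, prints the tag fields of the first alignment record:
```
samtools view ${OUTPUT_DIR}/aligned.bam | head -n 1 | tr '\t' '\n' | grep -E '^(mv|ts):'
```
If nothing is printed, re-run basecalling with `--emit-moves` before proceeding.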
70 | 71 | #### Option B: Perform Read Alignment after Bascalling with Dorado 72 | 73 | It is possible to run Dorado basecaller without performing alignment. This can be helpful in speeding up basecalling process that requires the use of a GPU instance which can be expensive. It also allows you more flexibility in terms of how you want to perform alignment, with specific minimap2 parameters. 74 | 75 | **Basecalling with Dorado** 76 | 77 | ``` 78 | # Perform basecalling 79 | ${INPUT_DIR}/dorado-0.5.3-linux-x64/bin/dorado basecaller --emit-moves --recursive ${INPUT_DIR}/dorado-0.5.3-linux-x64/models/dna_r10.4.1_e8.2_400bps_hac@v4.3.0 ${INPUT_DIR}/nanopore_raw_data > ${OUTPUT_DIR}/basecalled.bam 80 | ``` 81 | 82 | This will produce an unaligned BAM file named `basecalled.bam` under the `$OUTPUT_DIR` folder. 83 | 84 | **Alignment with minimap2** 85 | 86 | We will convert this BAM file into FASTQ format while keeping all the tags and pipe into minimap2 for alignment. 87 | 88 | ``` 89 | # Align using minimap2 while copying move table information 90 | samtools fastq ${OUTPUT_DIR}/basecalled.bam -T "*"|minimap2 -ax map-ont ${INPUT_DIR}/GRCh38.fa - -y|samtools view -o ${OUTPUT_DIR}/aligned.bam 91 | ``` 92 | 93 | This will produce an aligned BAM file named `aligned.bam` under the `$OUTPUT_DIR` folder. 94 | 95 | ### 1.1.2 (Optional) Read Phasing for diploid genomes 96 | You can optionally use SNP calling and haplotyping tool such as NanoCaller or Clair3 to phase the BAM file into parental haplotypes. The phased BAM file can be provided as input to DeepMod2 instead of `${OUTPUT_DIR}/aligned.bam` to get haplotype specific methylation calls. 97 | 98 | ``` 99 | #install NanoCaller 100 | conda install -c bioconda NanoCaller 101 | 102 | #sort and index the BAM file 103 | samtools sort ${OUTPUT_DIR}/aligned.bam -o ${OUTPUT_DIR}/aligned.sorted.bam 104 | samtools index ${OUTPUT_DIR}/aligned.sorted.bam 105 | 106 | #Run NanoCaller to phase the reads 107 | NanoCaller --bam ${OUTPUT_DIR}/aligned.sorted.bam --ref ${INPUT_DIR}/GRCh38.fa --mode snps --phase --output ${OUTPUT_DIR}/nanocaller --wgs_contigs chr1-22XY --cpu 8 108 | 109 | # Merge phased reads into a single BAM file 110 | find ${OUTPUT_DIR}/nanocaller/intermediate_phase_files -type f -name '*bam'|samtools cat -b - -o ${OUTPUT_DIR}/phased.bam 111 | 112 | ``` 113 | 114 | ### 1.1.3 Methylation Calling with DeepMod2 115 | Now we will run DeepMod2's `detect` module using `bilstm_r10.4.1_5khz_v4.3` model and use the aligned BAM file and Nanopore signal files as input. Since we want to perform reference anchored methylation calling, we will provide the reference genome FASTA file as input as well. We will use the phased BAM file from the previous step, but you can also use `${OUTPUT_DIR}/aligned.bam` BAM file if you do not want to get haplotype specific methylation calls. 
116 | 117 | ``` 118 | # Run DeepMod2 119 | BAM_INPUT=${OUTPUT_DIR}/phased.bam # Use ${OUTPUT_DIR}/aligned.bam if you did not use NanoCaller to phase the reads 120 | python ${INPUT_DIR}/DeepMod2/deepmod2 detect --model bilstm_r10.4.1_5khz_v4.3 --file_type pod5 --bam $BAM_INPUT --input ${INPUT_DIR}/nanopore_raw_data --output ${OUTPUT_DIR}/deepmod2/ --ref ${INPUT_DIR}/GRCh38.fa --threads 8 121 | ``` 122 | The output folder of DeepMod2 `${OUTPUT_DIR}/deepmod2/` will contain the following files: 123 | ``` 124 | args -> Shows the arguments and command use to run DeepMod2 125 | output.bam -> Unsorted methylation tagged BAM file 126 | output.per_read -> Per-read methylation calls in sorted BED file 127 | output.per_site -> Per-site methylation calls for +- strands separately in sorted BED file. 128 | output.per_site.aggregated -> Per-site methylation calls for with counts for +- strands combined. 129 | ``` 130 | 131 | **Per-read Output** 132 | We will inspect contents of the per-read output file using `head ${OUTPUT_DIR}/deepmod2/output.per_read`: 133 | 134 | 135 | |read_name|chromosome|ref_position_before|ref_position|read_position|strand|methylation_score|mean_read_qscore|read_length|read_phase|ref_cpg| 136 | |-|-|-|-|-|-|-|-|-|-|-| 137 | |160f871b-f4c3-40de-a160-383fcd5033e7|chr11|2733569|2733570|35|-|0.0036|18.05|48187|1|TRUE| 138 | |160f871b-f4c3-40de-a160-383fcd5033e7|chr11|2733457|2733458|146|-|0.9846|18.05|48187|1|TRUE| 139 | |160f871b-f4c3-40de-a160-383fcd5033e7|chr11|2733439|2733440|164|-|0.0048|18.05|48187|1|TRUE| 140 | |160f871b-f4c3-40de-a160-383fcd5033e7|chr11|2733362|2733363|242|-|0.2352|18.05|48187|1|TRUE| 141 | |160f871b-f4c3-40de-a160-383fcd5033e7|chr11|2733356|2733357|248|-|0.841|18.05|48187|1|TRUE| 142 | |160f871b-f4c3-40de-a160-383fcd5033e7|chr11|2733351|2733352|253|-|0.9893|18.05|48187|1|FALSE| 143 | |160f871b-f4c3-40de-a160-383fcd5033e7|chr11|2733341|2733342|263|-|0.012|18.05|48187|1|TRUE| 144 | |160f871b-f4c3-40de-a160-383fcd5033e7|chr11|2733155|2733156|449|-|0.9858|18.05|48187|1|TRUE| 145 | |160f871b-f4c3-40de-a160-383fcd5033e7|chr11|2733143|2733144|461|-|0.966|18.05|48187|1|TRUE| 146 | 147 | 148 | **Per-Site Output** 149 | 150 | We will use `bedtools intersect` to inspect per-site methylation frequencies in chr11:2699000-2702000 imprinting control region from the per-site output file with stranded CpG counts: 151 | 152 | ``` 153 | printf 'chr11\t2699000\t2702000'|bedtools intersect -header -a ${OUTPUT_DIR}/deepmod2/output.per_site -b -|head 154 | ``` 155 | 156 | |#chromosome|position_before|position|strand|ref_cpg|coverage|mod_coverage|unmod_coverage|mod_percentage|coverage_phase1|mod_coverage_phase1|unmod_coverage_phase1|mod_percentage_phase1|coverage_phase2|mod_coverage_phase2|unmod_coverage_phase2|mod_percentage_phase2| 157 | |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-| 158 | |chr11|2699031|2699032|+|TRUE|16|8|8|0.5|8|0|8|0|8|8|0|1| 159 | |chr11|2699032|2699033|-|TRUE|11|7|4|0.6364|4|0|4|0|7|7|0|1| 160 | |chr11|2699037|2699038|+|TRUE|16|8|8|0.5|8|0|8|0|8|8|0|1| 161 | |chr11|2699038|2699039|-|TRUE|11|7|4|0.6364|4|0|4|0|7|7|0|1| 162 | |chr11|2699048|2699049|+|TRUE|16|6|10|0.375|8|0|8|0|8|6|2|0.75| 163 | |chr11|2699049|2699050|-|TRUE|11|5|6|0.4545|4|0|4|0|7|5|2|0.7143| 164 | |chr11|2699099|2699100|+|TRUE|15|7|8|0.4667|8|0|8|0|7|7|0|1| 165 | |chr11|2699100|2699101|-|TRUE|11|7|4|0.6364|4|0|4|0|7|7|0|1| 166 | |chr11|2699101|2699102|+|TRUE|16|7|9|0.4375|8|0|8|0|8|7|1|0.875| 167 | 168 | 169 | **Aggregated Per-Site Output** 170 | 171 | We will use `bedtools intersect` to inspect 
per-site methylation frequencies in chr11:2699000-2702000 imprinting control region from the per-site output file with aggregated CpG counts over +- strands: 172 | 173 | ``` 174 | printf 'chr11\t2699000\t2702000'|bedtools intersect -header -a ${OUTPUT_DIR}/deepmod2/output.per_site.aggregated -b -|head 175 | ``` 176 | 177 | |#chromosome|position_before|position|ref_cpg|coverage|mod_coverage|unmod_coverage|mod_percentage|coverage_phase1|mod_coverage_phase1|unmod_coverage_phase1|mod_percentage_phase1|coverage_phase2|mod_coverage_phase2|unmod_coverage_phase2|mod_percentage_phase2| 178 | |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-| 179 | |chr11|2699031|2699033|TRUE|27|15|12|0.5556|12|0|12|0|15|15|0|1| 180 | |chr11|2699037|2699039|TRUE|27|15|12|0.5556|12|0|12|0|15|15|0|1| 181 | |chr11|2699048|2699050|TRUE|27|11|16|0.4074|12|0|12|0|15|11|4|0.7333| 182 | |chr11|2699099|2699101|TRUE|26|14|12|0.5385|12|0|12|0|14|14|0|1| 183 | |chr11|2699101|2699103|TRUE|27|14|13|0.5185|12|0|12|0|15|14|1|0.9333| 184 | |chr11|2699115|2699117|TRUE|27|14|13|0.5185|12|0|12|0|15|14|1|0.9333| 185 | |chr11|2699120|2699122|TRUE|27|15|12|0.5556|12|0|12|0|15|15|0|1| 186 | |chr11|2699180|2699182|TRUE|27|15|12|0.5556|12|0|12|0|15|15|0|1| 187 | |chr11|2699187|2699189|TRUE|27|15|12|0.5556|12|0|12|0|15|15|0|1| 188 | 189 | These results show that phase 1 is completely unmodified (column mod_percentage_phase1) whereas phase 2 is nearly completely modified (mod_percentage_phase2), which is what we expect for this imprinted region. 190 | 191 | 192 | ### 1.1.4 Visualizing DeepMod2 Methylation in IGV 193 | Since the methylation tagged BAM file produced by DeepMod2 is not sorted, we will first sort and index it: 194 | 195 | ``` 196 | samtools sort ${OUTPUT_DIR}/deepmod2/output.bam -o ${OUTPUT_DIR}/deepmod2/output.sorted.bam --write-index 197 | ``` 198 | 199 | Open the BAM file `${OUTPUT_DIR}/deepmod2/output.sorted.bam` in IGV, select `Color alignments by base modificaition (5mC)`. If you used phased BAM file for methylation, you can select `Group alignments by phase` to separate reads by haplotype. Go to the region `chr11:2699000-2702000` to see the following methylation tags: 200 | 201 | ![image](https://github.com/WGLab/DeepMod2/assets/35819083/b7e87a6c-9dda-4b31-be0e-93c13ecec1fb) 202 | 203 | 204 | ## 1.2 Reference free methylation calling 205 | In order to perform reference free methylation calling, we will provide an unaligned BAM to DeepMod2. In this case, DeepMod2 will detect 5mC in all CpG motifs found on the read only and will not use reference sequence as a feature. In this case, per-read predictions will not be combined into a per-site methylation frequency, and methylation will be reported only in per-read output and BAM file. 206 | 207 | ### 1.2.1 Dorado Basecalling 208 | 209 | ``` 210 | # Perform basecalling 211 | ${INPUT_DIR}/dorado-0.5.3-linux-x64/bin/dorado basecaller --emit-moves --recursive ${INPUT_DIR}/dorado-0.5.3-linux-x64/models/dna_r10.4.1_e8.2_400bps_hac@v4.3.0 ${INPUT_DIR}/nanopore_raw_data > ${OUTPUT_DIR}/basecalled.bam 212 | ``` 213 | 214 | This will produce an unaligned BAM file named `basecalled.bam` under the `$OUTPUT_DIR` folder. 215 | 216 | ### 1.2.2 Reference Free Methylation Calling with DeepMod2 217 | 218 | Now we will run DeepMod2's `detect` module using `bilstm_r10.4.1_5khz_v4.3` model and use the unaligned BAM file and Nanopore signal files as input. In this situation, we will not provide the reference genome FASTA file as input. 
219 | 220 | ``` 221 | # Run DeepMod2 222 | BAM_INPUT=${OUTPUT_DIR}/basecalled.bam 223 | python ${INPUT_DIR}/DeepMod2/deepmod2 detect --model bilstm_r10.4.1_5khz_v4.3 --file_type pod5 --bam $BAM_INPUT --input ${INPUT_DIR}/nanopore_raw_data --output ${OUTPUT_DIR}/deepmod2/ --threads 8 224 | ``` 225 | 226 | The output folder of DeepMod2 `${OUTPUT_DIR}/deepmod2/` will contain the following files: 227 | ``` 228 | args -> Shows the arguments and command use to run DeepMod2 229 | output.bam -> Unsorted methylation tagged BAM file 230 | output.per_read -> Per-read methylation calls in sorted BED file 231 | output.per_site -> Per-site methylation file will be empty. 232 | output.per_site.aggregated -> Aggregated Per-site methylation file will be empty. 233 | ``` 234 | 235 | **Per-read Output** 236 | We will inspect contents of the per-read output file using `head ${OUTPUT_DIR}/deepmod2/output.per_read`: 237 | 238 | 239 | |read_name|chromosome|ref_position_before|ref_position|read_position|strand|methylation_score|mean_read_qscore|read_length|read_phase|ref_cpg| 240 | |-|-|-|-|-|-|-|-|-|-|-| 241 | |160f871b-f4c3-40de-a160-383fcd5033e7|NA|NA|NA|35|+|0.0042|18.12|48175|0|FALSE| 242 | |160f871b-f4c3-40de-a160-383fcd5033e7|NA|NA|NA|146|+|0.985|18.12|48175|0|FALSE| 243 | |160f871b-f4c3-40de-a160-383fcd5033e7|NA|NA|NA|164|+|0.0057|18.12|48175|0|FALSE| 244 | |160f871b-f4c3-40de-a160-383fcd5033e7|NA|NA|NA|242|+|0.36|18.12|48175|0|FALSE| 245 | |160f871b-f4c3-40de-a160-383fcd5033e7|NA|NA|NA|248|+|0.9428|18.12|48175|0|FALSE| 246 | |160f871b-f4c3-40de-a160-383fcd5033e7|NA|NA|NA|253|+|0.9935|18.12|48175|0|FALSE| 247 | |160f871b-f4c3-40de-a160-383fcd5033e7|NA|NA|NA|263|+|0.0123|18.12|48175|0|FALSE| 248 | |160f871b-f4c3-40de-a160-383fcd5033e7|NA|NA|NA|449|+|0.9866|18.12|48175|0|FALSE| 249 | |160f871b-f4c3-40de-a160-383fcd5033e7|NA|NA|NA|461|+|0.968|18.12|48175|0|FALSE| 250 | 251 | 252 | In this case the chromosome, position and strand information is not available since the reads were unaligned. 253 | 254 | ### 1.2.3 Optional Read Alignment to Reference Genome and Per-site frequency calculation with modkit 255 | The unaligned BAM file produced by DeepMod2 contains 5mC tags MM and ML for all CpG motifs in a read. Reads from this BAM file can be aligned to any reference genome and methylation counts for reads mapping to the same reference CpG sites can be obtained using [modkit](https://github.com/nanoporetech/modkit). 256 | 257 | ``` 258 | # Align reads using minimap2 and then sort and index the BAM file 259 | samtools fastq ${OUTPUT_DIR}/deepmod2/output.bam -T "*"|minimap2 -ax map-ont ${INPUT_DIR}/GRCh38.fa - -y|samtools sort -o ${OUTPUT_DIR}/deepmod2/aligned.sorted.bam --write-index 260 | 261 | modkit pileup ${OUTPUT_DIR}/deepmod2/aligned.sorted.bam ${OUTPUT_DIR}/deepmod2/per_site_methylation.bed --ref ${INPUT_DIR}/GRCh38.fa --preset traditional 262 | ``` 263 | 264 | Per-site frequency calculation of modkit will be stored in `${OUTPUT_DIR}/deepmod2/per_site_methylation.bed` file in the column format described [here](https://github.com/nanoporetech/modkit?tab=readme-ov-file#bedmethyl-column-descriptions). 
We will use `bedtools intersect` to inspect per-site methylation frequencies in chr11:2699000-2702000 imprinting control region from the per-site output file with aggregated CpG counts over +- strands: 265 | 266 | ``` 267 | printf 'chr11\t2699900\t27002000'|bedtools intersect -header -a ${OUTPUT_DIR}/deepmod2/per_site_methylation.bed -b -|head 268 | ``` 269 | 270 | |chrom|start|end|code score|score|strand|start|end|color|N_valid_cov|fraction_modified|N_mod|N_canonical|N_other_mod|N_delete|N_fail|N_diff|N_nocall| 271 | |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-| 272 | |chr11|2699906|2699907|m|25|.|2699906|2699907|255,0,0|25|60|15|10|0|0|0|1|2| 273 | |chr11|2699915|2699916|m|27|.|2699915|2699916|255,0,0|27|55.56|15|12|0|0|0|0|1| 274 | |chr11|2699917|2699918|m|28|.|2699917|2699918|255,0,0|28|57.14|16|12|0|0|0|0|0| 275 | |chr11|2699927|2699928|m|27|.|2699927|2699928|255,0,0|27|55.56|15|12|0|0|0|0|1| 276 | |chr11|2699966|2699967|m|28|.|2699966|2699967|255,0,0|28|57.14|16|12|0|0|0|0|0| 277 | |chr11|2699976|2699977|m|27|.|2699976|2699977|255,0,0|27|55.56|15|12|0|0|1|0|0| 278 | |chr11|2699979|2699980|m|28|.|2699979|2699980|255,0,0|28|57.14|16|12|0|0|0|0|0| 279 | |chr11|2699981|2699982|m|28|.|2699981|2699982|255,0,0|28|57.14|16|12|0|0|0|0|0| 280 | |chr11|2699984|2699985|m|27|.|2699984|2699985|255,0,0|27|55.56|15|12|0|0|1|0|0| 281 | |chr11|2699988|2699989|m|27|.|2699988|2699989|255,0,0|27|55.56|15|12|0|0|1|0|0| 282 | 283 | Even though this result does not show haplotype specific methylation, we can see that the modified fraction is ~50-60% as expected. 284 | 285 | 286 | # 2. Methylation Calling from FAST5 files with Guppy basecalling 287 | **Prepare Directories** 288 | ``` 289 | INPUT_DIR=data 290 | OUTPUT_DIR=mod 291 | 292 | mkdir -p ${INPUT_DIR} 293 | mkdir -p ${OUTPUT_DIR} 294 | ``` 295 | **Download Software Packges** 296 | ``` 297 | # Install DeepMod2 298 | git clone https://github.com/WGLab/DeepMod2.git ${INPUT_DIR}/DeepMod2 299 | conda env create -f ${INPUT_DIR}/DeepMod2/environment.yml 300 | conda activate deepmod2 301 | conda install samtools minimap2 bedtools -y 302 | 303 | # Download Guppy basecaller 304 | wget -qO- https://cdn.oxfordnanoportal.com/software/analysis/ont-guppy_6.5.7_linux64.tar.gz| tar xzf - -C ${INPUT_DIR} 305 | ``` 306 | 307 | **Download Nanopore data and reference genome** 308 | ``` 309 | # Download reference genome 310 | wget ftp://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/000/001/405/GCF_000001405.39_GRCh38.p13/GRCh38_major_release_seqs_for_alignment_pipelines/GCA_000001405.15_GRCh38_no_alt_plus_hs38d1_analysis_set.fna.gz -O -| gunzip -c > ${INPUT_DIR}/GRCh38.fa 311 | wget ftp://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/000/001/405/GCF_000001405.39_GRCh38.p13/GRCh38_major_release_seqs_for_alignment_pipelines/GCA_000001405.15_GRCh38_no_alt_plus_hs38d1_analysis_set.fna.fai -O ${INPUT_DIR}/GRCh38.fa.fai 312 | 313 | # Download FAST5 files 314 | mkdir -p ${INPUT_DIR}/nanopore_raw_data 315 | 316 | wget -qO- https://github.com/WGLab/DeepMod2/files/14368873/sample.fast5.tar.gz| tar xzf - -C ${INPUT_DIR}/nanopore_raw_data 317 | ``` 318 | 319 | ## 2.1 Reference Anchored Methylation Calling 320 | In order to perform reference anchored methylation calling, we will provide an aligned BAM to DeepMod2. In this case, DeepMod2 will detect 5mC in all CpG motifs found on the read, as wel as any bases of the read that map to a reference CpG site. Finally, it will combine per-read predictions for a given reference CpG site into a per-site methylation frequency. 
Any unaligned reads or unaligned segments of the reads will also be analyzed for 5mC and reported in per-read output and BAM file, but would not be used in per-site frequency calculation. 321 | 322 | ### 2.1.1 Guppy Basecalling and Read Alignment to Reference Genome 323 | First we will perform basecalling of our nanopore signal file using Guppy basecaller. It is possible to align the reads during basecalling or align the reads after basecalling. Both options are shown below. Since we need move table for our basecalled DNA sequences, we will use `--moves_out` while running Guppy, which will produce an aligned (Option A) or unaligned BAM file (Option B). 324 | 325 | #### Option A: Perform Read Alignment during Bascalling with Guppy 326 | 327 | Guppy has the option to perform read alignment using minimap2 during basecalling if a reference FASTA file is provided as `--align_ref` option. This can be be helpful in reducing the number of steps needed to run. 328 | 329 | ``` 330 | ${INPUT_DIR}/ont-guppy/bin/guppy_basecaller -i ${INPUT_DIR}/nanopore_raw_data -s ${OUTPUT_DIR}/basecalls --align_ref ${INPUT_DIR}/GRCh38.fa -c dna_r10.4.1_e8.2_400bps_5khz_hac.cfg --bam_out --moves_out --recursive 331 | ``` 332 | 333 | This will produce several aligned BAM files under the `${OUTPUT_DIR}/basecalls` subfolders `pass` and `fail`. We will combine these BAM files into a single BAM file to give as input to DeepMod2, and we have a choice of using both pass and fail reads or just pass reads. 334 | 335 | ``` 336 | find ${OUTPUT_DIR}/basecalls \( -path "*/pass/*" -o -path "*/fail/*" \) -type f -name "*.bam"|samtools cat -b - -o ${OUTPUT_DIR}/aligned.bam 337 | ``` 338 | 339 | #### Option B: Perform Read Alignment after Bascalling with Guppy 340 | 341 | It is possible to run Guppy basecaller without performing alignment. This can be helpful in speeding up basecalling process that requires the use of a GPU instance which can be expensive. It also allows you more flexibility in terms of how you want to perform alignment, with specific minimap2 parameters. 342 | 343 | **Basecalling with Guppy** 344 | 345 | ``` 346 | # Perform basecalling 347 | ${INPUT_DIR}/ont-guppy/bin/guppy_basecaller -i ${INPUT_DIR}/nanopore_raw_data -s ${OUTPUT_DIR}/basecalls -c dna_r10.4.1_e8.2_400bps_5khz_hac.cfg --bam_out --moves_out --recursive 348 | ``` 349 | 350 | This will produce several aligned BAM files under the `${OUTPUT_DIR}/basecalls` subfolders `pass` and `fail`. We will align read sequences from these BAM files and we have a choice of using both pass and fail reads or just pass reads. 351 | 352 | **Alignment with minimap2** 353 | 354 | We will convert this BAM file into FASTQ format while keeping all the tags and pipe into minimap2 for alignment. 355 | 356 | ``` 357 | # Align using minimap2 while copying move table information 358 | 359 | find ${OUTPUT_DIR}/basecalls \( -path "*/pass/*" -o -path "*/fail/*" \) -type f -name "*.bam"|samtools cat -b -|samtools fastq - -T "*"|minimap2 -ax map-ont ${INPUT_DIR}/GRCh38.fa - -y|samtools view -o ${OUTPUT_DIR}/aligned.bam 360 | ``` 361 | 362 | This will produce an aligned BAM file named `aligned.bam` under the `$OUTPUT_DIR` folder. 
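Before moving on to methylation calling, you can optionally summarize the merged BAM file to confirm that reads from the `pass` and `fail` folders were combined and mapped as expected (a simple check, assuming `samtools` is installed):
```
samtools flagstat ${OUTPUT_DIR}/aligned.bam
```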
363 | 364 | --------- 365 | ### 2.1.2 Methylation Calling with DeepMod2 and Guppy Basecalling 366 | Once you have an aligned BAM file, you can continue with DeepMod2 methylation calling as described in [1.1.3 Methylation Calling with DeepMod2](Example.md#113-methylation-calling-with-deepmod2), you would only need to change `--file_type` parameter to `fast5` since we are using FAST5 file format. 367 | 368 | ## 2.2 Reference free methylation calling with Guppy Basecalling 369 | You can perform Guppy basecalling without reference alignment: 370 | 371 | ``` 372 | # Perform basecalling 373 | ${INPUT_DIR}/ont-guppy/bin/guppy_basecaller -i ${INPUT_DIR}/nanopore_raw_data -s ${OUTPUT_DIR}/basecalls -c dna_r10.4.1_e8.2_400bps_5khz_hac.cfg --bam_out --moves_out --recursive 374 | ``` 375 | 376 | and continue with methylation calling as described in [1.2.2 Reference Free Methylation Calling with DeepMod2](Example.md#122-reference-free-methylation-calling-with-deepmod2). 377 | -------------------------------------------------------------------------------- /docs/Install.md: -------------------------------------------------------------------------------- 1 | ## Conda Installation 2 | First, install Miniconda, a minimal installation of Anaconda, which is much smaller and has a faster installation: 3 | ``` 4 | curl -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh 5 | bash Miniconda3-latest-Linux-x86_64.sh 6 | ``` 7 | Go through all the prompts (installation in `$HOME` is recommended). After Anaconda is installed successfully, simply run: 8 | ``` 9 | git clone https://github.com/WGLab/DeepMod2.git 10 | conda env create -f DeepMod2/environment.yml 11 | conda activate deepmod2 12 | ``` 13 | 14 | After installing, run `python DeepMod2/deepmod2 --help` to see the run options. 15 | 16 | ### Installation for GPU 17 | If you want to use GPU to accelerate DeepMod2, make sure to install cuda enabled version of pytorch. Details for GPU accelerated pytorch can be found here https://pytorch.org/get-started/locally/ and you can select the installation command best suited to your system. One xample of installation command is shown below for CUDA 11.8: 18 | 19 | `conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia` 20 | -------------------------------------------------------------------------------- /docs/Usage.md: -------------------------------------------------------------------------------- 1 | # DeepMod2 Usage 2 | 3 | 4 | DeepMod2 has two run options `detect` and `merge` which can be shown using `python deepmod2 --help` 5 | ``` 6 | usage: deepmod2 [-h] [--print_models] {detect,merge} ... 7 | 8 | options: 9 | -h, --help show this help message and exit 10 | --print_models Print details of models available (default: False) 11 | 12 | Options: 13 | {detect,merge} 14 | detect Call methylation from Guppy or Dorado basecalled POD5/FAST5 files using move tables for signal alignment. 15 | merge Merge per-read calls into per-site calls 16 | 17 | ``` 18 | 19 | In short, `detect` is for methylation calling using Dorado and Guppy generated POD5/FAST5 files, respectively, whereas `merge` is for merging per-read methylation calls from several runs into per-site methylation calls. 20 | 21 | `python deepmod2 --print_models` shows names of available models as well as their description, such as datasets used for training. Names of these models can be used with `--model` argument to specify which model to use. 
Please use the appropriate DeepMod2 model for the flowcell you are using, e.g. R9.4.1 vs R10.4.1. 22 | 23 | ## detect-guppy 24 | DeepMod2 can detect 5mC methylation from Dorado and Guppy basecalled POD5/FAST5 signal files using `detect` option. In addition to signal files, you are required to provide a BAM file containing read sequences. It is recommended that the reads in the BAM file are aligned to a reference genome for more accurate results, in which case you should also provide a reference genome FASTA file. You are required to provide move tables for the signal files, so make sure you output move tables when you run the basecaller. The move table information can be part of the BAM file as mv and ts tags, or it can be part of FAST5 file (which is deprecated now). 25 | 26 | ``` 27 | usage: deepmod2 detect [-h] [--prefix PREFIX] [--output OUTPUT] [--qscore_cutoff QSCORE_CUTOFF] [--length_cutoff LENGTH_CUTOFF] [--mod_t MOD_T] [--unmod_t UNMOD_T] [--include_non_cpg_ref] 28 | [--threads THREADS] [--ref REF] --model MODEL --bam BAM --file_type {fast5,pod5} --input INPUT [--guppy_group GUPPY_GROUP] [--chrom [CHROM ...]] [--fast5_move] 29 | [--skip_per_site] [--device DEVICE] [--disable_pruning] [--exclude_ref_features] [--batch_size BATCH_SIZE] [--bam_threads BAM_THREADS] [--skip_unmapped] 30 | 31 | options: 32 | -h, --help show this help message and exit 33 | --prefix PREFIX Prefix for the output files (default: output) 34 | --output OUTPUT Path to folder where intermediate and final files will be stored, default is current working directory (default: None) 35 | --qscore_cutoff QSCORE_CUTOFF 36 | Minimum cutoff for mean quality score of a read (default: 0) 37 | --length_cutoff LENGTH_CUTOFF 38 | Minimum cutoff for read length (default: 0) 39 | --mod_t MOD_T Probability threshold for a per-read prediction to be considered modified. Only predictiond with probability >= mod_t will be considered as modified for 40 | calculation of per-site modification levels. (default: 0.5) 41 | --unmod_t UNMOD_T Probability threshold for a per-read prediction to be considered unmodified. Only predictiond with probability < unmod_t will be considered as unmodified for 42 | calculation of per-site modification levels. (default: 0.5) 43 | --include_non_cpg_ref 44 | Include non-CpG reference loci in per-site output where reads have CpG motif. (default: False) 45 | --threads THREADS Number of threads to use for processing signal and running model inference. If a GPU is used for inference, then --threads number of threads will be running on GPU 46 | concurrently. The total number of threads used by DeepMod2 is equal to --threads plus --bam_threads. It is recommended to run DeepMod2 with mutliple cores, and use 47 | at least 4 bam_threads for compressing BAM file. (default: 4) 48 | --ref REF Path to reference FASTA file to anchor methylation calls to reference loci. If no reference is provided, only the motif loci on reads will be used. (default: None) 49 | --guppy_group GUPPY_GROUP 50 | Name of the guppy basecall group (default: Basecall_1D_000) 51 | --chrom [CHROM ...] A space/whitespace separated list of contigs, e.g. chr3 chr6 chr22. If not list is provided then all chromosomes in the reference are used. (default: None) 52 | --fast5_move Use move table from FAST5 file instead of BAM file. If this flag is set, specify a basecall group for FAST5 file using --guppy_group parameter and ensure that the 53 | FAST5 files contains move table. 
(default: False) 54 | --skip_per_site Skip per site output (default: False) 55 | --device DEVICE Device to use for running pytorch models. you can set --device=cpu for cpu, or --device=cuda for GPU. You can also specify a particular GPU device such as 56 | --device=cuda:0 or --device=cuda:1 . If --device paramater is not set by user, then GPU will be used if available otherwise CPU will be used. (default: None) 57 | --disable_pruning Disable model pruning (not recommended for CPU inference). By default models are pruned to remove some weights with low L1 norm in linear layers. Pruning has 58 | little effect on model accuracy, it can signifcantly improve CPU inference time but not GPU inference time. (default: False) 59 | --exclude_ref_features 60 | Exclude reference sequence from feature matrix. By default, if a reference FASTA file is provided via --ref parameter, then the reference sequence is added as a 61 | feature for aligned reads, but not if a read is unmapped or if no reference is provided. (default: False) 62 | --batch_size BATCH_SIZE 63 | Batch size to use for GPU inference. For CPU inference, batch size is fixed at 512. (default: 1024) 64 | --bam_threads BAM_THREADS 65 | Number of threads to use for compressed BAM output. Setting it lower than 4 can significantly lower the runtime. (default: 4) 66 | --skip_unmapped Skip unmapped reads from methylation calling (default: False) 67 | 68 | Required Arguments: 69 | --model MODEL Name of the model. Recommended model for R9.4.1 flowcells is "bilstm_r9.4.1", for R10.4.1 flowcell (4kHz sampling) it is "bilstm_r10.4.1_4khz". Use --print_models 70 | to display all models available. (default: None) 71 | --bam BAM Path to aligned or unaligned BAM file. It is ideal to have move table in BAM file but move table from FAST5 fies can also be used. Aligned BAM file is required for 72 | reference anchored methylation calls, otherwise only the motif loci on reads will be called. (default: None) 73 | --file_type {fast5,pod5} 74 | Specify whether the signal is in FAST5 or POD5 file format. If POD5 file is used, then move table must be in BAM file. (default: None) 75 | --input INPUT Path to POD5/FAST5 file or folder containing POD5/FAST5 files. If folder provided, then POD5/FAST5 files will be recusrviely searched (default: None) 76 | 77 | ``` 78 | 79 | ## merge 80 | For large datasets, it can be useful to split the data and run multiple instances of DeepMod2 for speedup. In this case, you can run `detect` with `--skip_per_site` option and use `deepmod2 merge` to merge per-read calls into per-site calls. 81 | 82 | ``` 83 | usage: deepmod2 merge [-h] [--prefix PREFIX] [--output OUTPUT] [--qscore_cutoff QSCORE_CUTOFF] [--length_cutoff LENGTH_CUTOFF] [--mod_t MOD_T] [--unmod_t UNMOD_T] [--include_non_cpg_ref] 84 | [--input [INPUT ...]] [--list LIST] 85 | 86 | options: 87 | -h, --help show this help message and exit 88 | --prefix PREFIX Prefix for the output files (default: output) 89 | --output OUTPUT Path to folder where intermediate and final files will be stored, default is current working directory (default: None) 90 | --qscore_cutoff QSCORE_CUTOFF 91 | Minimum cutoff for mean quality score of a read (default: 0) 92 | --length_cutoff LENGTH_CUTOFF 93 | Minimum cutoff for read length (default: 0) 94 | --mod_t MOD_T Probability threshold for a per-read prediction to be considered modified. Only predictiond with probability >= mod_t will be considered as modified for 95 | calculation of per-site modification levels. 
(default: 0.5) 96 | --unmod_t UNMOD_T Probability threshold for a per-read prediction to be considered unmodified. Only predictiond with probability < unmod_t will be considered as unmodified for 97 | calculation of per-site modification levels. (default: 0.5) 98 | --include_non_cpg_ref 99 | Include non-CpG reference loci in per-site output where reads have CpG motif. (default: False) 100 | --input [INPUT ...] List of paths of per-read methylation calls to merge. File paths should be separated by space/whitespace. Use either --input or --list argument, but not both. 101 | (default: None) 102 | --list LIST A file containing paths to per-read methylation calls to merge (one per line). Use either --inputs or --list argument, but not both. (default: None) 103 | ``` -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: deepmod2 2 | channels: 3 | - conda-forge 4 | - bioconda 5 | - defaults 6 | dependencies: 7 | - numpy 8 | - pysam 9 | - python>=3.8 10 | - pytorch 11 | - h5py 12 | - tqdm 13 | - numba 14 | - ont-fast5-api 15 | - pip 16 | - pip: 17 | - pod5 18 | -------------------------------------------------------------------------------- /plot_utils/README.md: -------------------------------------------------------------------------------- 1 | # Signal Plotting 2 | 3 | The Jupyter notebook [Signal_Plot_Examples.ipynb](Signal_Plot_Examples.ipynb) uses the following data: [plot_files.tar.gz](https://github.com/WGLab/DeepMod2/files/14985308/plot_files.tar.gz) 4 | -------------------------------------------------------------------------------- /plot_utils/plot.py: -------------------------------------------------------------------------------- 1 | import pod5 as p5 2 | import numpy as np 3 | from pathlib import Path 4 | import os, pysam 5 | import pandas as pd 6 | from scipy import stats 7 | import matplotlib.pyplot as plt 8 | import matplotlib as mpl 9 | 10 | import matplotlib.patches as mpatches 11 | from matplotlib.lines import Line2D 12 | import plotly.graph_objects as go 13 | 14 | comp_base_map={'A':'T','T':'A','C':'G','G':'C'} 15 | base_map={'A':0, 'C':1, 'G':2, 'T':3, 'U':3} 16 | rev_base_map={0:'A', 1:'C', 2:'G', 3:'T'} 17 | strand_map={'+':0, '-':1} 18 | 19 | def get_x_and_y_axes(split_signal): 20 | x_axes = [] 21 | y_axes = [] 22 | x_axis = np.array([]) 23 | y_axis = [] 24 | 25 | for i, base_signal in enumerate(split_signal): 26 | x_axis = np.hstack([x_axis, np.linspace(i+0.05, i+1-0.05, base_signal.shape[0])]) 27 | y_axis.append(base_signal) 28 | 29 | return x_axis, np.hstack(y_axis) 30 | 31 | 32 | def plot_read(read_data, marker_transparency=0.8, line_plot=True, marker_size=20, lim=0, display_average=True, save_path=None): 33 | split_signals, seq=read_data['signal'], read_data['seq'] 34 | if lim==0: 35 | u_lim=np.max(np.hstack(split_signals)) 36 | l_lim=np.min(np.hstack(split_signals)) 37 | else: 38 | u_lim, l_lim=lim, -1*lim 39 | 40 | K=len(seq) 41 | 42 | p_x_axis, p_y_axis=get_x_and_y_axes(split_signals) 43 | 44 | plt.figure(figsize=(12, 3),dpi=100) 45 | 46 | if line_plot: 47 | plt.plot(p_x_axis, p_y_axis, linewidth=0.5, alpha=0.5, color='green') 48 | plt.scatter(p_x_axis, p_y_axis,alpha=marker_transparency, s=marker_size, color='green', edgecolor='black') 49 | 50 | plt.axhline(y=0, linestyle='dotted') 51 | for i in range(K): 52 | plt.axvline(x=i, linestyle='dotted', ymin=-4, ymax=4) 53 | 54 | handles=[] 55 | if display_average: 56 | for i in 
range(0,K): 57 | plt.hlines(y=np.mean(split_signals[i]),xmin=i,xmax=i+1,color='red',linestyle='-') 58 | plt.hlines(y=np.median(split_signals[i]),xmin=i,xmax=i+1,color='blue',linestyle='-') 59 | 60 | handles.append(mpatches.Patch(color='red', label='Mean Signal')) 61 | handles.append(mpatches.Patch(color='blue', label='Median Signal')) 62 | 63 | plt.axis([0, K, 1.1*l_lim, 1.1*u_lim]) 64 | plt.xticks(np.arange(0, K) + 1/2, list(seq)) 65 | 66 | handles.append(Line2D([0], [0], marker='o', color='white', label='Signal', alpha=0.5, markeredgecolor='black', markerfacecolor='green', markersize=5)) 67 | leg=plt.legend(handles=handles,loc='upper right',fontsize=10) 68 | for patch in leg.get_patches(): 69 | patch.set_height(5) 70 | 71 | plt.ylabel('Signal') 72 | plt.xlabel('Bases') 73 | 74 | if save_path!=None: 75 | plt.savefig(save_path, bbox_inches='tight') 76 | 77 | plt.show() 78 | 79 | def plot_single_sample(data, line_plot=False, color='green',lim=0, marker_size=10, marker_transparency=0.2, display_average=True, save_path=None): 80 | plt.figure(figsize=(12, 3),dpi=100) 81 | 82 | if lim==0: 83 | u_lim=max(np.max(np.hstack(x['signal'])) for x in data.values()) 84 | l_lim=min(np.min(np.hstack(x['signal'])) for x in data.values()) 85 | else: 86 | u_lim, l_lim=lim, -1*lim 87 | 88 | 89 | cons_seq=get_consensus(data) 90 | K=len(cons_seq) 91 | 92 | bin_list=[[] for i in range(K)] 93 | 94 | for read_data in data.values(): 95 | for i in range(K): 96 | for signal in read_data['signal'][i]: 97 | bin_list[i].append(signal) 98 | 99 | p_x_axis, p_y_axis=get_x_and_y_axes(read_data['signal']) 100 | 101 | 102 | if line_plot: 103 | plt.plot(p_x_axis, p_y_axis, linewidth=0.5, alpha=0.5, color=color) 104 | 105 | plt.scatter(p_x_axis, p_y_axis, alpha=marker_transparency, s=marker_size, color=color, edgecolor=None) 106 | 107 | plt.axhline(y=0, linestyle='dotted') 108 | for i in range(K): 109 | plt.axvline(x=i, linestyle='dotted', ymin=-4, ymax=4) 110 | 111 | handles=[] 112 | if display_average: 113 | for i in range(0,K): 114 | plt.hlines(y=np.mean(bin_list[i]),xmin=i,xmax=i+1,color='black',linestyle='-') 115 | plt.hlines(y=np.median(bin_list[i]),xmin=i,xmax=i+1,color='blue',linestyle='-') 116 | handles.append(mpatches.Patch(color='black', label='Mean Signal')) 117 | handles.append(mpatches.Patch(color='blue', label='Median Signal')) 118 | 119 | plt.axis([0, K, 1.1*l_lim, 1.1*u_lim]) 120 | plt.xticks(np.arange(0, K) + 1/2, cons_seq) 121 | 122 | handles.append(Line2D([0], [0], marker='o', color='white', label='Signal', alpha=0.5, markeredgecolor='black', markerfacecolor=color, markersize=5)) 123 | leg=plt.legend(handles=handles,loc='upper right',fontsize=10) 124 | for patch in leg.get_patches(): 125 | patch.set_height(5) 126 | plt.ylabel('Signal') 127 | plt.xlabel('Bases') 128 | 129 | if save_path!=None: 130 | plt.savefig(save_path,bbox_inches='tight') 131 | 132 | plt.show() 133 | 134 | def plot_two_samples(sample1_data, sample2_data, label1='Mod', label2='Unmod', line_plot=False, lim=0, marker_size=10, marker_transparency=0.2, display_average=True, average_type='median', save_path=None, color1='green', color2='red'): 135 | 136 | plt.figure(figsize=(12, 3),dpi=100) 137 | 138 | tmp_dict=dict(sample1_data) 139 | tmp_dict.update(sample2_data) 140 | cons_seq=get_consensus(tmp_dict) 141 | 142 | if lim==0: 143 | u_lim=max(np.max(np.hstack(x['signal'])) for x in tmp_dict.values()) 144 | l_lim=min(np.min(np.hstack(x['signal'])) for x in tmp_dict.values()) 145 | else: 146 | u_lim, l_lim=lim, -1*lim 147 | 148 | K=len(cons_seq) 
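# The two loops below pool raw signal samples by base position: for every read in each
# sample, each signal point assigned to base i is appended to sample_1_bin_list[i] or
# sample_2_bin_list[i]. The per-position overlays drawn further down (mean or median,
# depending on average_type) are therefore computed across all reads, not within one read.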
149 | 150 | sample_1_bin_list=[[] for i in range(K)] 151 | sample_2_bin_list=[[] for i in range(K)] 152 | 153 | 154 | for read_data in sample1_data.values(): 155 | for i in range(K): 156 | for signal in read_data['signal'][i]: 157 | sample_1_bin_list[i].append(signal) 158 | 159 | p_x_axis, p_y_axis=get_x_and_y_axes(read_data['signal']) 160 | 161 | 162 | if line_plot: 163 | plt.plot(p_x_axis, p_y_axis, linewidth=0.5, alpha=0.5, color=color1) 164 | 165 | plt.scatter(p_x_axis, p_y_axis, alpha=marker_transparency, s=marker_size, color=color1, edgecolor=None) 166 | 167 | for read_data in sample2_data.values(): 168 | for i in range(K): 169 | for signal in read_data['signal'][i]: 170 | sample_2_bin_list[i].append(signal) 171 | 172 | p_x_axis, p_y_axis=get_x_and_y_axes(read_data['signal']) 173 | 174 | 175 | if line_plot: 176 | plt.plot(p_x_axis, p_y_axis, linewidth=0.5, alpha=0.5, color=color2) 177 | 178 | plt.scatter(p_x_axis, p_y_axis, alpha=marker_transparency, s=marker_size, color=color2, edgecolor=None) 179 | 180 | plt.axhline(y=0, linestyle='dotted') 181 | for i in range(K): 182 | plt.axvline(x=i, linestyle='dotted', ymin=-4, ymax=4) 183 | 184 | handles=[] 185 | if display_average: 186 | for i in range(0,K): 187 | if average_type=='mean': 188 | plt.hlines(y=np.mean(sample_1_bin_list[i]),xmin=i,xmax=i+1,color=color1,linestyle='-') 189 | plt.hlines(y=np.mean(sample_2_bin_list[i]),xmin=i,xmax=i+1,color=color2,linestyle='-') 190 | 191 | elif average_type=='median': 192 | plt.hlines(y=np.median(sample_1_bin_list[i]),xmin=i,xmax=i+1,color=color1,linestyle='-') 193 | plt.hlines(y=np.median(sample_2_bin_list[i]),xmin=i,xmax=i+1,color=color2,linestyle='-') 194 | 195 | handles.append(mpatches.Patch(color=color1, label='{} {} Signal'.format(label1, average_type.capitalize()))) 196 | handles.append(mpatches.Patch(color=color2, label='{} {} Signal'.format(label2, average_type.capitalize()))) 197 | 198 | plt.axis([0, K, 1.1*l_lim, 1.1*u_lim]) 199 | plt.xticks(np.arange(0, K) + 1/2, cons_seq) 200 | 201 | handles.append(Line2D([0], [0], marker='o', color='white', label='{} Signal'.format(label1), alpha=0.5, markeredgecolor='black', markerfacecolor='green', markersize=5)) 202 | handles.append(Line2D([0], [0], marker='o', color='white', label='{} Signal'.format(label2), alpha=0.5, markeredgecolor='black', markerfacecolor='red', markersize=5)) 203 | leg=plt.legend(handles=handles,loc='upper right',fontsize=10) 204 | 205 | for patch in leg.get_patches(): 206 | patch.set_height(5) 207 | plt.ylabel('Signal') 208 | plt.xlabel('Bases') 209 | 210 | if save_path!=None: 211 | plt.savefig(save_path,bbox_inches='tight') 212 | 213 | plt.show() 214 | 215 | def violin_plot(sample_data, avg_type='median', static_display=False, meanline_visible=False, figure_width=1000, figure_height=500, save_path=None): 216 | 217 | cons_seq=get_consensus(sample_data) 218 | 219 | K=len(cons_seq) 220 | d=[] 221 | 222 | 223 | j=0 224 | for i in range(K): 225 | for read_data in sample_data.values(): 226 | if avg_type=='mean': 227 | d.append([j,np.mean(read_data['signal'][i])]) 228 | elif avg_type=='median': 229 | d.append([j,np.median(read_data['signal'][i])]) 230 | j+=1 231 | 232 | df=pd.DataFrame(d) 233 | df.rename(columns={0:'Position',1:'Signal'}, inplace=True) 234 | 235 | fig = go.Figure() 236 | 237 | fig.add_trace(go.Violin(x=df['Position'], 238 | y=df['Signal'], 239 | points=False, 240 | line=dict(color="blue", width=0.5),meanline=dict(color="blue", width=2), 241 | meanline_visible=meanline_visible) 242 | ) 243 | 244 | 
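# The x-axis tick labels built next pair each position's offset from the centre base
# (index - len(cons_seq)//2) with the consensus base at that position; plotly renders the
# '<br>' inside the label string as a line break, stacking the offset above the base letter.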
tickvals=np.arange(0,len(cons_seq)) 245 | ticktext=['%d<br>
%s' %(a-len(cons_seq)//2,b) for a,b in zip(tickvals, cons_seq)] 246 | fig.update_xaxes(tickmode='array', tickvals=tickvals, ticktext=ticktext) 247 | fig.update_layout(violingap=0, violinmode='overlay') 248 | 249 | 250 | 251 | if static_display: 252 | fig.update_layout(autosize=False, width=figure_width, height=figure_height) 253 | if save_path!=None: 254 | fig.write_image(save_path) 255 | fig.show(renderer="svg") 256 | else: 257 | if save_path!=None: 258 | fig.write_html(save_path) 259 | fig.show() 260 | 261 | return 262 | 263 | 264 | def compare_violin_plot(sample1_data, sample2_data, label1='Mod', label2='Unmod', avg_type='median', static_display=False, meanline_visible=False, figure_width=1000, figure_height=500, save_path=None, test_type='mw', test_method="auto", display_pval=True): 265 | tmp_dict=dict(sample1_data) 266 | tmp_dict.update(sample2_data) 267 | cons_seq=get_consensus(tmp_dict) 268 | 269 | K=len(cons_seq) 270 | d=[] 271 | 272 | 273 | j=0 274 | for i in range(K): 275 | for read_data in sample1_data.values(): 276 | if avg_type=='mean': 277 | d.append([j,np.mean(read_data['signal'][i])]) 278 | elif avg_type=='median': 279 | d.append([j,np.median(read_data['signal'][i])]) 280 | j+=1 281 | 282 | pos_df=pd.DataFrame(d) 283 | pos_df.rename(columns={0:'Position',1:'Signal'}, inplace=True) 284 | pos_df['Sample']=label1 285 | 286 | d=[] 287 | 288 | 289 | j=0 290 | for i in range(K): 291 | for read_data in sample2_data.values(): 292 | if avg_type=='mean': 293 | d.append([j,np.mean(read_data['signal'][i])]) 294 | elif avg_type=='median': 295 | d.append([j,np.median(read_data['signal'][i])]) 296 | j+=1 297 | 298 | neg_df=pd.DataFrame(d) 299 | neg_df.rename(columns={0:'Position',1:'Signal'}, inplace=True) 300 | neg_df['Sample']=label2 301 | 302 | df=pd.concat([pos_df, neg_df]) 303 | fig = go.Figure() 304 | 305 | fig.add_trace(go.Violin(x=df['Position'][ df['Sample'] == label1], 306 | y=df['Signal'][ df['Sample'] == label1 ], 307 | legendgroup=label1, scalegroup=label1, name=label1, points=False, 308 | side='negative', 309 | line=dict(color="blue", width=0.5),meanline=dict(color="blue", width=2), 310 | meanline_visible=meanline_visible) 311 | ) 312 | fig.add_trace(go.Violin(x=df['Position'][ df['Sample'] == label2], 313 | y=df['Signal'][ df['Sample'] == label2 ], 314 | legendgroup=label2, scalegroup=label2, name=label2, points=False, 315 | side='positive', 316 | line=dict(color="orange", width=0.5), meanline=dict(color="orange", width=2), 317 | meanline_visible=meanline_visible) 318 | ) 319 | 320 | 321 | group_df=df.groupby(['Position', 'Sample'])['Signal'].apply(list) 322 | dist_stats={} 323 | for i in range(len(group_df)//2): 324 | if test_type=='ks': 325 | s=stats.ks_2samp(group_df.loc[i,:].loc[label1], group_df.loc[i,:].loc[label2], method=test_method) 326 | elif test_type=="mw": 327 | s=stats.mannwhitneyu(group_df.loc[i,:].loc[label1], group_df.loc[i,:].loc[label2], method=test_method) 328 | dist_stats[i]={'Position':i,'Base':cons_seq[i],'Statistic':s.statistic, 'Pvalue':s.pvalue} 329 | 330 | dist_stats_df=pd.DataFrame(dist_stats).T 331 | dist_stats_df=dist_stats_df.astype({'Position': 'int', 'Statistic': 'float32', 'Pvalue': 'float32',}) 332 | 333 | tickvals=np.arange(0,len(cons_seq)) 334 | if display_pval: 335 | ticktext=['{}
<br>{}<br>
{:0.1e}'.format(a-len(cons_seq)//2,b,c) for a,b,c in zip(tickvals, cons_seq, dist_stats_df.Pvalue)] 336 | else: 337 | ticktext=['%d<br>
%s' %(a-len(cons_seq)//2,b) for a,b in zip(tickvals, cons_seq)] 338 | fig.update_xaxes(tickmode='array', tickvals=tickvals, ticktext=ticktext,tickfont = dict(size = 8)) 339 | fig.update_layout(violingap=0, violinmode='overlay') 340 | 341 | if static_display: 342 | fig.update_layout(autosize=False, width=figure_width, height=figure_height) 343 | if save_path!=None: 344 | fig.write_image(save_path) 345 | fig.show(renderer="svg") 346 | else: 347 | if save_path!=None: 348 | fig.write_html(save_path) 349 | fig.show() 350 | 351 | return dist_stats_df 352 | 353 | 354 | def revcomp(s): 355 | return ''.join(comp_base_map[x] for x in s[::-1]) 356 | 357 | def get_consensus(data): 358 | K=len(next(iter(data.values()))['seq']) 359 | cons_seq_array=np.zeros((K,4)) 360 | for x in data.values(): 361 | for i in range(K): 362 | cons_seq_array[i][base_map[x['seq'][i]]]+=1 363 | 364 | cons_seq=''.join(rev_base_map[t] for t in np.argmax(cons_seq_array,axis=1)) 365 | return cons_seq 366 | 367 | def get_file_names(base_path): 368 | read_filename_dict={} 369 | 370 | if os.path.isdir(base_path): 371 | files=Path(base_path).rglob('*.pod5') 372 | else: 373 | files=[base_path] 374 | for read_path in files: 375 | read_path=str(read_path) 376 | with p5.Reader(read_path) as reader: 377 | for rname in reader.read_ids: 378 | read_filename_dict[rname]=read_path 379 | 380 | return read_filename_dict 381 | 382 | def get_read_positions(bam_path, chrom, pos, strand, seq_type): 383 | flag=0x4|0x100|0x200|0x400|0x800 384 | 385 | read_info={} 386 | 387 | bam=pysam.Samfile(bam_path,'rb') 388 | 389 | for pcol in bam.pileup(contig=chrom, start=pos-1, end=pos, flag_filter=flag, truncate=True, min_base_quality = 0): 390 | if strand=='+': 391 | for read in pcol.pileups: 392 | if read.alignment.is_reverse==False: 393 | if read.is_del: 394 | continue 395 | print('DEL', read.alignment.qname) 396 | else: 397 | if seq_type=='dna': 398 | read_info[read.alignment.qname]=(read.query_position, False, read.alignment.to_dict()) 399 | elif seq_type=='rna': 400 | read_info[read.alignment.qname]=(read.alignment.query_length-read.query_position-1, False, read.alignment.to_dict()) 401 | elif strand=='-': 402 | for read in pcol.pileups: 403 | if read.alignment.is_reverse: 404 | if read.is_del: 405 | continue 406 | print('DEL', read.alignment.qname) 407 | else: 408 | if seq_type=='dna': 409 | read_info[read.alignment.qname]=(read.alignment.query_length-read.query_position-1, True, read.alignment.to_dict()) 410 | 411 | elif seq_type=='rna': 412 | read_info[read.alignment.qname]=(read.query_position, True, read.alignment.to_dict()) 413 | 414 | return read_info 415 | 416 | def get_read_signal_raw(signal, move,norm_type): 417 | stride, start, move_table=move 418 | median=np.median(signal) 419 | mad=np.median(np.abs(signal-median)) 420 | 421 | if norm_type=='MAD': 422 | norm_signal=(signal-median)/mad 423 | elif norm_type=='STD': 424 | norm_signal=(signal-np.mean(signal))/np.std(signal) 425 | 426 | move_len=len(move_table) 427 | move_index=np.where(move_table)[0] 428 | rlen=len(move_index) 429 | 430 | base_level_data=[] 431 | for i in range(len(move_index)-1): 432 | prev=move_index[i]*stride+start 433 | sig_end=move_index[i+1]*stride+start 434 | base_level_data.append([prev,sig_end-prev]) 435 | 436 | return norm_signal, base_level_data 437 | 438 | def get_signals(bam_path, chrom, pos, strand, read_filename_dict, base_path, seq_type, max_cov=1000, window_before=10, window_after=10, norm_type='STD'): 439 | read_info=get_read_positions(bam_path, chrom, pos, 
strand,seq_type) 440 | data={} 441 | 442 | if seq_type=='rna': 443 | window_before, window_after=window_after, window_before 444 | 445 | cov=0 446 | for read_name in read_info.keys(): 447 | if cov > max_cov: 448 | break 449 | try: 450 | read_path=read_filename_dict[read_name] 451 | except KeyError: 452 | continue 453 | 454 | with p5.Reader(read_path) as reader: 455 | raw_read=next(reader.reads([read_name])) 456 | try: 457 | read_pos, reverse, read_dict=read_info[read_name] 458 | 459 | except KeyError: 460 | continue 461 | 462 | tags={x.split(':')[0]:x for x in read_dict['tags']} 463 | start=int(tags['ts'].split(':')[-1]) 464 | mv=tags['mv'].split(',') 465 | 466 | stride=int(mv[1]) 467 | move_table=np.fromiter(mv[2:], dtype=np.int8) 468 | move=(stride, start, move_table) 469 | 470 | signal=raw_read.signal 471 | 472 | fq=read_dict['seq'] 473 | fq=revcomp(fq) if reverse else fq 474 | 475 | if seq_type=='rna': 476 | fq=fq[::-1] 477 | 478 | norm_signal, base_level_data = get_read_signal_raw(signal, move, norm_type) 479 | seq_len=len(fq) 480 | 481 | if read_pos>window_before+5 and read_pos5]=5 39 | signal[signal<-5]=-5 40 | 41 | move_len=len(move_table) 42 | 43 | data=np.zeros((move_len,stride+2)) 44 | indexes=np.full(move_len,fill_value=0,dtype=np.int32) 45 | z=1 46 | idx=-1 47 | 48 | segments=np.full(np.sum(move_table),fill_value=0,dtype=np.int32) 49 | 50 | for i in range(move_len): 51 | if move_table[i]: 52 | z=z^1 53 | idx+=1 54 | segments[idx]=i 55 | 56 | data[i,z]=1 57 | indexes[i]=idx 58 | for k in range(stride): 59 | data[i,2+k]=signal[start+i*stride+k] 60 | 61 | return data, indexes, segments 62 | 63 | @jit(nopython=True) 64 | def get_events(signal, move, norm_type): 65 | stride, start, move_table=move 66 | 67 | if norm_type=='mad': 68 | median=np.median(signal) 69 | mad=np.median(np.abs(signal-median)) 70 | signal=(signal-median)/mad 71 | 72 | else: 73 | mean=np.mean(signal) 74 | std=np.std(signal) 75 | signal=(signal-mean)/std 76 | 77 | move_len=len(move_table) 78 | move_index=np.where(move_table)[0] 79 | rlen=len(move_index) 80 | 81 | data=np.zeros((rlen,9)) 82 | 83 | for i in range(len(move_index)-1): 84 | prev=move_index[i]*stride+start 85 | sig_end=move_index[i+1]*stride+start 86 | 87 | sig_len=sig_end-prev 88 | data[i, 8]=np.log10(sig_len) 89 | data[i, 4]=np.median(signal[prev:sig_end]) 90 | data[i, 5]=np.median(np.abs(signal[prev:sig_end]-data[i, 4])) 91 | data[i, 6]=np.mean(signal[prev:sig_end]) 92 | data[i, 7]=np.std(signal[prev:sig_end]) 93 | 94 | for j in range(4): 95 | tmp_cnt=0 96 | for t in range(j*sig_len//4,min(sig_len, (j+1)*sig_len//4)): 97 | data[i, j]+=signal[t+prev] 98 | tmp_cnt+=1 99 | data[i, j]=data[i, j]/tmp_cnt 100 | 101 | return data 102 | 103 | def get_candidates(read_seq, align_data, aligned_pairs, ref_pos_dict, exp_motif_seq, motif_base, motif_ind, position_based): 104 | if align_data[0]: 105 | is_mapped, is_forward, ref_name, reference_start, reference_end, read_length=align_data 106 | 107 | base_id={m.start(0):i for i,m in enumerate(re.finditer(r'(?={})'.format(motif_base), read_seq))} 108 | 109 | motif_anchor=np.array([m.start(0) for m in re.finditer(r'(?={})'.format(exp_motif_seq), read_seq)]) 110 | motif_id=np.array(sorted(list(set.union(*[set(motif_anchor+i) for i in motif_ind])))) 111 | 112 | ref_motif_pos=ref_pos_dict[ref_name][0] if is_forward else ref_pos_dict[ref_name][1] 113 | 114 | common_pos=ref_motif_pos[(ref_motif_pos>=reference_start)&(ref_motif_pos0 and not position_based: 127 | aligned_pairs_read_wise=aligned_pairs_read_wise[motif_id] 
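# Note on the motif search above: the zero-width lookahead pattern reports every motif
# start, including overlapping occurrences. A minimal illustration (assuming a CpG call,
# i.e. exp_motif_seq == 'CG'):
#     >>> import re
#     >>> [m.start(0) for m in re.finditer(r'(?=CG)', 'ACGCGT')]
#     [1, 3]
# motif_id is the union of these anchors shifted by the offsets in motif_ind, i.e. the
# read coordinates of the bases to be called.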
128 | 129 | #if need to disable clipped bases 130 | #aligned_pairs_read_wise=aligned_pairs_read_wise[(reference_start<=aligned_pairs_read_wise[:,1]) & (aligned_pairs_read_wise[:,1]counter_check: 202 | counter_check=counter//10000 203 | print('%s: Number of reads processed: %d' %(str(datetime.datetime.now()), counter), flush=True) 204 | if res[0]: 205 | _, total_read_info, total_candidate_list, total_MM_list, read_qual_list, pred_list = res 206 | for read_data, candidate_list, MM, ML, pred_list in zip(*res[1:]): 207 | counter+=1 208 | read_dict, read_info = read_data 209 | read=pysam.AlignedSegment.from_dict(read_dict,header) 210 | if MM: 211 | read.set_tag('MM',MM,value_type='Z') 212 | read.set_tag('ML',ML) 213 | 214 | outf.write(read) 215 | 216 | read_name=read_dict['name'] 217 | is_forward, chrom, read_length, mean_qscore=read_info 218 | chrom=chrom if chrom else 'NA' 219 | 220 | strand='+' if is_forward else '-' 221 | 222 | phase=0 223 | phase=read.get_tag('HP') if read.has_tag('HP') else 0 224 | 225 | idx=4*phase+2*is_forward 226 | 227 | if float(mean_qscore)unmod_threshold: 249 | pass 250 | elif not skip_per_site: 251 | mod=score>=mod_threshold 252 | 253 | if (chrom, zero_based_fwd_pos) not in per_site_pred: 254 | per_site_pred[(chrom, zero_based_fwd_pos)]=[0]*12+[is_ref_cpg] 255 | 256 | per_site_pred[(chrom, zero_based_fwd_pos)][idx+mod]+=1 257 | 258 | per_read_file.write('%s\t%s\t%s\t%s\t%d\t%s\t%.4f\t%.2f\t%d\t%d\t%s\n' %(read_name, chrom, ref_pos_str_before, ref_pos_str_after, read_pos, strand, score, mean_qscore, read_length, phase, is_ref_cpg)) 259 | 260 | 261 | else: 262 | _, total_read_info=res 263 | for read_dict in total_read_info: 264 | counter+=1 265 | read=pysam.AlignedSegment.from_dict(read_dict,header) 266 | outf.write(read) 267 | 268 | except queue.Empty: 269 | pass 270 | 271 | print('%s: Number of reads processed: %d' %(str(datetime.datetime.now()), counter), flush=True) 272 | print('%s: Finished Per-Read Methylation Output. Starting Per-Site output.' %str(datetime.datetime.now()), flush=True) 273 | print('%s: Modification Tagged BAM file: %s' %(str(datetime.datetime.now()),bam_output), flush=True) 274 | print('%s: Per Read Prediction file: %s' %(str(datetime.datetime.now()), per_read_file_path), flush=True) 275 | print('%s: Writing Per Site Methylation Detection.' 
%str(datetime.datetime.now()), flush=True) 276 | 277 | if skip_per_site: 278 | return 279 | per_site_fields=['#chromosome', 'position_before', 'position','strand', 'ref_cpg', 280 | 'coverage','mod_coverage', 'unmod_coverage','mod_fraction', 281 | 'coverage_phase1','mod_coverage_phase1', 'unmod_coverage_phase1','mod_fraction_phase1', 282 | 'coverage_phase2','mod_coverage_phase2', 'unmod_coverage_phase2','mod_fraction_phase2'] 283 | per_site_header='\t'.join(per_site_fields)+'\n' 284 | per_site_fields.remove('strand') 285 | agg_per_site_header='\t'.join(per_site_fields)+'\n' 286 | 287 | per_site_file_path=os.path.join(params['output'],'%s.per_site' %params['prefix']) 288 | agg_per_site_file_path=os.path.join(params['output'],'%s.per_site.aggregated' %params['prefix']) 289 | 290 | with open(per_site_file_path, 'w') as per_site_file, open(agg_per_site_file_path,'w') as agg_per_site_file: 291 | per_site_file.write(per_site_header) 292 | agg_per_site_file.write(agg_per_site_header) 293 | 294 | for x in sorted(per_site_pred.keys()): 295 | chrom, pos=x 296 | cpg=per_site_pred[x] 297 | is_ref_cpg=cpg[12] 298 | 299 | if cpg_ref_only and is_ref_cpg==False: 300 | continue 301 | #fwd_stats=[self.chrom, self.position, self.position+1, '+', self.is_ref_cpg]+self.get_all_phases().forward.stats() + self.phase_1.forward.stats() + self.phase_2.forward.stats() 302 | 303 | agg_stats, fwd_stats, rev_stats=get_stats_string_cpg(chrom, pos, is_ref_cpg, cpg) 304 | if agg_stats[0]>0: 305 | agg_per_site_file.write(agg_stats[1]) 306 | 307 | if fwd_stats[0]>0: 308 | per_site_file.write(fwd_stats[1]) 309 | 310 | if rev_stats[0]>0: 311 | per_site_file.write(rev_stats[1]) 312 | 313 | print('%s: Finished Writing Per Site Methylation Output.' %str(datetime.datetime.now()), flush=True) 314 | print('%s: Per Site Prediction file: %s' %(str(datetime.datetime.now()), per_site_file_path), flush=True) 315 | print('%s: Aggregated Per Site Prediction file: %s' %(str(datetime.datetime.now()), agg_per_site_file_path), flush=True) 316 | 317 | return 318 | 319 | def get_output(params, output_Q, methylation_event, header_dict, ref_pos_dict): 320 | header=pysam.AlignmentHeader.from_dict(header_dict) 321 | 322 | bam_threads=params['bam_threads'] 323 | 324 | output=params['output'] 325 | bam_output=os.path.join(output,'%s.bam' %params['prefix']) 326 | per_read_file_path=os.path.join(output,'%s.per_read' %params['prefix']) 327 | 328 | 329 | per_site_file_path=os.path.join(output,'%s.per_site' %params['prefix']) 330 | agg_per_site_file_path=os.path.join(output,'%s.per_site.aggregated' %params['prefix']) 331 | qscore_cutoff=params['qscore_cutoff'] 332 | length_cutoff=params['length_cutoff'] 333 | 334 | mod_threshold=params['mod_t'] 335 | unmod_threshold=params['unmod_t'] 336 | 337 | skip_per_site=params['skip_per_site'] 338 | 339 | per_site_pred={} 340 | 341 | counter=0 342 | 343 | counter_check=0 344 | with open(per_read_file_path,'w') as per_read_file: 345 | per_read_file.write('read_name\tchromosome\tref_position_before\tref_position\tread_position\tstrand\tmethylation_score\tmean_read_qscore\tread_length\tread_phase\n') 346 | 347 | with pysam.AlignmentFile(bam_output, "wb", threads=bam_threads, header=header) as outf: 348 | while True: 349 | if methylation_event.is_set() and output_Q.empty(): 350 | break 351 | else: 352 | try: 353 | res = output_Q.get(block=False, timeout=10) 354 | #continue 355 | if counter//10000>counter_check: 356 | counter_check=counter//10000 357 | print('%s: Number of reads processed: %d' 
%(str(datetime.datetime.now()), counter), flush=True) 358 | if res[0]: 359 | _, total_read_info, total_candidate_list, total_MM_list, read_qual_list, pred_list = res 360 | for read_data, candidate_list, MM, ML, pred_list in zip(*res[1:]): 361 | counter+=1 362 | read_dict, read_info = read_data 363 | read=pysam.AlignedSegment.from_dict(read_dict,header) 364 | if MM: 365 | read.set_tag('MM',MM,value_type='Z') 366 | read.set_tag('ML',ML) 367 | 368 | outf.write(read) 369 | 370 | read_name=read_dict['name'] 371 | is_forward, chrom, read_length, mean_qscore=read_info 372 | chrom=chrom if chrom else 'NA' 373 | 374 | strand='+' if is_forward else '-' 375 | 376 | phase=0 377 | phase=read.get_tag('HP') if read.has_tag('HP') else 0 378 | 379 | 380 | if float(mean_qscore)unmod_threshold: 396 | pass 397 | elif not skip_per_site: 398 | mod=score>=mod_threshold 399 | 400 | if (chrom, ref_pos,strand) not in per_site_pred: 401 | per_site_pred[(chrom, ref_pos,strand)]=[0]*6 402 | 403 | per_site_pred[(chrom, ref_pos,strand)][2*phase+mod]+=1 404 | 405 | per_read_file.write('%s\t%s\t%s\t%s\t%d\t%s\t%.4f\t%.2f\t%d\t%d\n' %(read_name, chrom, ref_pos_str_before, ref_pos_str_after, read_pos, strand, score, mean_qscore, read_length, phase)) 406 | 407 | 408 | else: 409 | _, total_read_info=res 410 | for read_dict in total_read_info: 411 | counter+=1 412 | read=pysam.AlignedSegment.from_dict(read_dict,header) 413 | outf.write(read) 414 | 415 | except queue.Empty: 416 | pass 417 | 418 | print('%s: Number of reads processed: %d' %(str(datetime.datetime.now()), counter), flush=True) 419 | print('%s: Finished Per-Read Methylation Output. Starting Per-Site output.' %str(datetime.datetime.now()), flush=True) 420 | print('%s: Modification Tagged BAM file: %s' %(str(datetime.datetime.now()),bam_output), flush=True) 421 | print('%s: Per Read Prediction file: %s' %(str(datetime.datetime.now()), per_read_file_path), flush=True) 422 | 423 | if skip_per_site: 424 | return 425 | per_site_fields=['#chromosome', 'position_before', 'position','strand', 426 | 'coverage','mod_coverage', 'unmod_coverage','mod_fraction', 427 | 'coverage_phase1','mod_coverage_phase1', 'unmod_coverage_phase1','mod_fraction_phase1', 428 | 'coverage_phase2','mod_coverage_phase2', 'unmod_coverage_phase2','mod_fraction_phase2'] 429 | per_site_header='\t'.join(per_site_fields)+'\n' 430 | 431 | 432 | per_site_file_path=os.path.join(params['output'],'%s.per_site' %params['prefix']) 433 | 434 | with open(per_site_file_path, 'w') as per_site_file: 435 | per_site_file.write(per_site_header) 436 | 437 | for x in sorted(per_site_pred.keys()): 438 | chrom, pos, strand=x 439 | mod_call=per_site_pred[x] 440 | 441 | stats=get_stats_string(chrom, pos, strand, mod_call) 442 | if stats[0]>0: 443 | per_site_file.write(stats[1]) 444 | 445 | print('%s: Finished Writing Per Site Methylation Output.' 
%str(datetime.datetime.now()), flush=True) 446 | print('%s: Per Site Prediction file: %s' %(str(datetime.datetime.now()), per_site_file_path), flush=True) 447 | return 448 | 449 | def process(params,ref_pos_dict, signal_Q, output_Q, input_event, ref_seq_dict): 450 | torch.set_grad_enabled(False); 451 | 452 | dev=params['dev'] 453 | motif_seq=params['motif_seq'] 454 | exp_motif_seq=params['exp_motif_seq'] 455 | motif_base=motif_seq[params['motif_ind'][0]] 456 | motif_ind=params['motif_ind'] 457 | 458 | if params['mod_symbol']: 459 | mod_symbol=params['mod_symbol'] 460 | elif motif_seq=='CG': 461 | mod_symbol='m' 462 | else: 463 | mod_symbol=motif_base 464 | 465 | seq_type=params['seq_type'] 466 | 467 | position_based=params['position_based'] 468 | 469 | base_map={'A':0, 'C':1, 'G':2, 'T':3, 'U':3} 470 | 471 | cigar_map={'M':0, '=':0, 'X':0, 'D':1, 'I':2, 'S':2,'H':2, 'N':1, 'P':4, 'B':4} 472 | cigar_pattern = r'\d+[A-Za-z=]' 473 | 474 | model, model_config=get_model(params) 475 | window=model_config['window'] 476 | full_signal=model_config['full_signal'] 477 | strides_per_base=model_config['strides_per_base'] 478 | norm_type=model_config['norm_type'] 479 | 480 | model.eval() 481 | model.to(dev); 482 | print('CUDA test:', torch.cuda.is_available(), next(model.parameters()).is_cuda, flush=True) 483 | 484 | reads_per_round=100 485 | 486 | chunk_size=256 if dev=='cpu' else params['batch_size'] 487 | 488 | total_candidate_list=[] 489 | total_feature_list=[] 490 | total_base_seq_list=[] 491 | total_MM_list=[] 492 | total_read_info=[] 493 | total_c_idx=[] 494 | total_unprocessed_reads=[] 495 | total_ref_seq_list=[] 496 | r_count=0 497 | 498 | dummy_ref_seq=4+np.zeros(1000000) 499 | 500 | ref_available=True if params['ref'] else False 501 | 502 | while True: 503 | if (signal_Q.empty() and input_event.is_set()): 504 | break 505 | 506 | try: 507 | 508 | chunk=signal_Q.get(block=False, timeout=10) 509 | #print('%s: Output_qsize=%d Signal_qsize=%d' %(str(datetime.datetime.now()), output_Q.qsize(), signal_Q.qsize()),flush=True) 510 | if output_Q.qsize()>200: 511 | time.sleep(30) 512 | if output_Q.qsize()>500: 513 | time.sleep(60) 514 | print('Pausing output due to queue size limit. 
Output_qsize=%d Signal_qsize=%d' %(output_Q.qsize(), signal_Q.qsize()), flush=True) 515 | for data in chunk: 516 | signal, move, read_dict, align_data=data 517 | 518 | is_mapped, is_forward, ref_name, reference_start, reference_end, read_length=align_data 519 | 520 | fq=read_dict['seq'] 521 | qual=read_dict['qual'] 522 | sequence_length=len(fq) 523 | reverse= not is_forward 524 | fq=revcomp(fq) if reverse else fq 525 | qual=qual[::-1] if reverse else qual 526 | 527 | if is_mapped and ref_available: 528 | cigar_tuples = np.array([(int(x[:-1]), cigar_map[x[-1]]) for x in re.findall(cigar_pattern, read_dict['cigar'])]) 529 | ref_start=int(read_dict['ref_pos'])-1 530 | aligned_pairs=get_aligned_pairs(cigar_tuples, ref_start) 531 | else: 532 | aligned_pairs=None 533 | 534 | 535 | pos_list_c, pos_list_candidates, read_to_ref_pairs=get_candidates(fq, align_data, aligned_pairs, ref_pos_dict, exp_motif_seq, motif_base, motif_ind, position_based) 536 | 537 | pos_list_candidates=pos_list_candidates[(pos_list_candidates[:,0]>window*strides_per_base)\ 538 | &(pos_list_candidates[:,0]0 else pos_list_candidates 539 | 540 | if len(pos_list_candidates)==0: 541 | total_unprocessed_reads.append(read_dict) 542 | continue 543 | 544 | if not move[0]: 545 | try: 546 | tags={x.split(':')[0]:x for x in read_dict.pop('tags')} 547 | start=int(tags['ts'].split(':')[-1]) 548 | mv=tags['mv'].split(',') 549 | 550 | stride=int(mv[1]) 551 | move_table=np.fromiter(mv[2:], dtype=np.int8) 552 | move=(stride, start, move_table) 553 | read_dict['tags']=[x for x in tags.values() if x[:2] not in ['mv', 'ts', 'ML', 'MM']] 554 | except KeyError: 555 | print('Read:%s no move table or stride or signal start found' %read_dict['name']) 556 | total_unprocessed_reads.append(read_dict) 557 | continue 558 | 559 | base_seq=np.array([base_map[x] for x in fq]) 560 | base_qual=10**((33-np.array([ord(x) for x in qual]))/10) 561 | mean_qscore=-10*np.log10(np.mean(base_qual)) 562 | base_qual=(1-base_qual)[:,np.newaxis] 563 | 564 | if is_mapped and ref_available and not params['exclude_ref_features']: 565 | ref_seq=ref_seq_dict[ref_name][:,1][read_to_ref_pairs[:, 1]][::-1] if reverse else ref_seq_dict[ref_name][:,0][read_to_ref_pairs[:, 1]] 566 | 567 | else: 568 | ref_seq=dummy_ref_seq 569 | 570 | 571 | if full_signal: 572 | mat, indexes, segments=get_segment_events(signal, move, norm_type) 573 | segments_ends=np.concatenate([segments[1:],np.array([len(move_table)])]) 574 | segment_ranges=np.vstack([segments, segments_ends]).T 575 | if seq_type=='rna': 576 | mat=np.flip(mat,axis=0) 577 | indexes=len(read_dict['seq'])-np.flip(indexes,axis=0)-1 578 | segment_ranges=np.flip(np.flip(len(move_table)-segment_ranges,axis=1),axis=0) 579 | else: 580 | mat=get_events(signal, move, norm_type) 581 | if seq_type=='rna': 582 | mat=np.flip(mat,axis=0) 583 | mat=np.hstack((mat, base_qual)) 584 | 585 | try: 586 | c_idx=[True if x in pos_list_c else False for x in pos_list_candidates[:,0]] 587 | c_idx_count=np.vectorize(pos_list_c.get)(pos_list_candidates[c_idx,0]) 588 | c_idx_count[1:]=c_idx_count[1:]-c_idx_count[:-1]-1 589 | MM='{}+{}?,'.format(motif_base,mod_symbol)+','.join(c_idx_count.astype(str))+';' 590 | total_c_idx.append(c_idx) 591 | total_MM_list.append(MM) 592 | 593 | except ValueError: 594 | total_c_idx.append([]) 595 | total_MM_list.append(None) 596 | 597 | 598 | if full_signal: 599 | base_seq_full=base_seq[indexes] 600 | ref_seq_full=ref_seq[indexes] 601 | base_qual_full=base_qual[indexes] 602 | 
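# For full-signal models, each candidate base receives a fixed-size slice of the per-stride
# feature matrix mat: from strides_per_base*window strides before the candidate's first
# stride up to strides_per_base*(window+1) strides past it. A boolean mask flagging the
# candidate's own strides and the per-stride base qualities are stacked on as extra
# channels, and the per-stride read and reference base sequences are sliced over the same window.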
per_site_features=np.array([mat[segment_ranges[candidate[0]][0]-strides_per_base*window:segment_ranges[candidate[0]][0]+strides_per_base*(window+1)] for candidate in pos_list_candidates]) 603 | per_site_base_qual=np.array([base_qual_full[segment_ranges[candidate[0]][0]-strides_per_base*window:segment_ranges[candidate[0]][0]+strides_per_base*(window+1)] for candidate in pos_list_candidates]) 604 | per_site_indexes=np.array([(indexes==candidate[0])[segment_ranges[candidate[0]][0]-strides_per_base*window:segment_ranges[candidate[0]][0]+strides_per_base*(window+1)] for candidate in pos_list_candidates]) 605 | per_site_features=np.dstack([per_site_features, per_site_indexes[:,:,np.newaxis], per_site_base_qual]) 606 | 607 | per_site_base_seq=np.array([base_seq_full[segment_ranges[candidate[0]][0]-strides_per_base*window:segment_ranges[candidate[0]][0]+strides_per_base*(window+1)] for candidate in pos_list_candidates]) 608 | per_site_ref_seq=np.array([ref_seq_full[segment_ranges[candidate[0]][0]-strides_per_base*window:segment_ranges[candidate[0]][0]+strides_per_base*(window+1)] for candidate in pos_list_candidates]) 609 | 610 | 611 | else: 612 | per_site_features=np.array([mat[candidate[0]-window: candidate[0]+window+1] for candidate in pos_list_candidates]) 613 | per_site_base_seq=np.array([base_seq[candidate[0]-window: candidate[0]+window+1] for candidate in pos_list_candidates]) 614 | per_site_ref_seq=np.array([ref_seq[candidate[0]-window: candidate[0]+window+1] for candidate in pos_list_candidates]) 615 | 616 | total_candidate_list.append(pos_list_candidates) 617 | total_feature_list.append(per_site_features) 618 | total_base_seq_list.append(per_site_base_seq) 619 | total_ref_seq_list.append(per_site_ref_seq) 620 | 621 | total_read_info.append((read_dict, [align_data[1],align_data[2],align_data[5], mean_qscore])) 622 | 623 | if len(total_read_info)>=reads_per_round: 624 | read_counts=np.cumsum([len(x) for x in total_feature_list])[:-1] 625 | features_list=np.vstack(total_feature_list) 626 | base_seq_list=np.vstack(total_base_seq_list) 627 | ref_seq_list=np.vstack(total_ref_seq_list) 628 | 629 | pred_list=[model(batch_x.to(dev), batch_base_seq.to(dev), batch_ref_seq.to(dev)).cpu().numpy() for batch_x, batch_base_seq, batch_ref_seq in generate_batches(features_list, base_seq_list, window, ref_seq=ref_seq_list, batch_size = chunk_size)] 630 | 631 | 632 | pred_list=np.vstack(pred_list) 633 | pred_list=np.split(pred_list.ravel(), read_counts) 634 | read_qual_list=[array.array('B',np.round(255*read_pred_list[c_idx]).astype(int)) for read_pred_list, c_idx in zip(pred_list, total_c_idx)] 635 | 636 | output_Q.put([True, total_read_info, total_candidate_list, total_MM_list, read_qual_list, pred_list]) 637 | total_candidate_list, total_feature_list, total_base_seq_list, total_MM_list, total_read_info, total_c_idx=[], [], [], [], [], [] 638 | total_ref_seq_list=[] 639 | 640 | if len(total_unprocessed_reads)>100: 641 | output_Q.put([False, total_unprocessed_reads]) 642 | total_unprocessed_reads=[] 643 | 644 | except queue.Empty: 645 | pass 646 | 647 | if len(total_read_info)>0: 648 | read_counts=np.cumsum([len(x) for x in total_feature_list])[:-1] 649 | features_list=np.vstack(total_feature_list) 650 | base_seq_list=np.vstack(total_base_seq_list) 651 | ref_seq_list=np.vstack(total_ref_seq_list) 652 | 653 | pred_list=[model(batch_x.to(dev), batch_base_seq.to(dev), batch_ref_seq.to(dev)).cpu().numpy() for batch_x, batch_base_seq, batch_ref_seq in generate_batches(features_list, base_seq_list, window, 
ref_seq=ref_seq_list, batch_size = chunk_size)] 654 | 655 | pred_list=np.vstack(pred_list) 656 | pred_list=np.split(pred_list.ravel(), read_counts) 657 | read_qual_list=[array.array('B',np.round(255*read_pred_list[c_idx]).astype(int)) for read_pred_list, c_idx in zip(pred_list, total_c_idx)] 658 | 659 | output_Q.put([True, total_read_info, total_candidate_list, total_MM_list, read_qual_list, pred_list]) 660 | 661 | if len(total_unprocessed_reads)>0: 662 | output_Q.put([False, total_unprocessed_reads]) 663 | 664 | return 665 | 666 | def get_input(params, signal_Q, output_Q, input_event): 667 | chrom_list=params['chrom_list'] 668 | length_cutoff=params['length_cutoff'] 669 | 670 | skip_unmapped=params['skip_unmapped'] 671 | 672 | bam=params['bam'] 673 | bam_file=pysam.AlignmentFile(bam,'rb',check_sq=False) 674 | 675 | print('%s: Building BAM index.' %str(datetime.datetime.now()), flush=True) 676 | bam_index=pysam.IndexedReads(bam_file) 677 | bam_index.build() 678 | print('%s: Finished building BAM index.' %str(datetime.datetime.now()), flush=True) 679 | 680 | input_=params['input'] 681 | signal_files= [input_] if os.path.isfile(input_) else Path(input_).rglob("*.%s" %params['file_type']) 682 | 683 | chunk=[] 684 | non_primary_reads=[] 685 | reads_per_chunk=100 686 | 687 | max_qsize=1000 if params['seq_type']=='rna' else 200 688 | 689 | if params['file_type']=='fast5': 690 | guppy_group=params['guppy_group'] 691 | for filename in signal_files: 692 | with get_fast5_file(filename, mode="r") as f5: 693 | for read in f5.get_reads(): 694 | if signal_Q.qsize()>200: 695 | time.sleep(20) 696 | #print('Pausing input due to INPUT queue size limit. Signal_qsize=%d' %(signal_Q.qsize()), flush=True) 697 | read_name=read.read_id 698 | non_primary_reads=[] 699 | try: 700 | read_iter=bam_index.find(read_name) 701 | for bam_read in read_iter: 702 | if (params['ref'] and bam_read.is_mapped and bam_read.reference_name not in chrom_list)\ 703 | or bam_read.query_length < length_cutoff \ 704 | or (bam_read.is_mapped==False and skip_unmapped==True): 705 | continue 706 | 707 | elif not (bam_read.is_supplementary or bam_read.is_secondary): 708 | read_dict=bam_read.to_dict() 709 | signal=read.get_raw_data() 710 | 711 | if params['fast5_move']: 712 | segment=read.get_analysis_attributes(guppy_group)['segmentation'] 713 | start=read.get_analysis_attributes('%s/Summary/segmentation' %segment)['first_sample_template'] 714 | stride=read.get_summary_data(guppy_group)['basecall_1d_template']['block_stride'] 715 | move_table=read.get_analysis_dataset('%s/BaseCalled_template' %guppy_group, 'Move') 716 | move=(stride, start, move_table) 717 | else: 718 | move=(None,None,None) 719 | 720 | align_data=(bam_read.is_mapped if params['ref'] else False, 721 | bam_read.is_forward, bam_read.reference_name, bam_read.reference_start, bam_read.reference_end, bam_read.query_length) 722 | data=(signal, move, read_dict, align_data) 723 | chunk.append(data) 724 | if len(chunk)>=reads_per_chunk: 725 | signal_Q.put(chunk) 726 | chunk=[] 727 | 728 | else: 729 | pass 730 | #non_primary_reads.append(read_dict) 731 | 732 | if len(non_primary_reads)>0: 733 | output_Q.put([False, non_primary_reads]) 734 | 735 | except KeyError: 736 | continue 737 | 738 | else: 739 | move=(None,None,None) 740 | for filename in signal_files: 741 | with p5.Reader(filename) as reader: 742 | for read in reader.reads(): 743 | if signal_Q.qsize()>max_qsize: 744 | time.sleep(20) 745 | print('Pausing input due to INPUT queue size limit. 
Signal_qsize=%d' %(signal_Q.qsize()), flush=True) 746 | 747 | read_name=str(read.read_id) 748 | non_primary_reads=[] 749 | try: 750 | read_iter=bam_index.find(read_name) 751 | for bam_read in read_iter: 752 | 753 | if (params['ref'] and bam_read.is_mapped and bam_read.reference_name not in chrom_list)\ 754 | or bam_read.query_length < length_cutoff \ 755 | or (bam_read.is_mapped==False and skip_unmapped==True): 756 | continue 757 | 758 | elif not (bam_read.is_supplementary or bam_read.is_secondary): 759 | read_dict=bam_read.to_dict() 760 | signal=read.signal 761 | align_data=(bam_read.is_mapped if params['ref'] else False, 762 | bam_read.is_forward, bam_read.reference_name, bam_read.reference_start, bam_read.reference_end, bam_read.query_length) 763 | data=(signal, move, read_dict, align_data) 764 | chunk.append(data) 765 | if len(chunk)>=reads_per_chunk: 766 | signal_Q.put(chunk) 767 | chunk=[] 768 | 769 | else: 770 | pass 771 | #non_primary_reads.append(read_dict) 772 | 773 | if len(non_primary_reads)>0: 774 | output_Q.put([False, non_primary_reads]) 775 | 776 | except KeyError: 777 | #print('Read:%s not found in BAM file' %read_name, flush=True) 778 | continue 779 | 780 | if len(chunk)>0: 781 | signal_Q.put(chunk) 782 | chunk=[] 783 | 784 | if len(non_primary_reads)>0: 785 | output_Q.put([False, non_primary_reads]) 786 | 787 | input_event.set() 788 | return 789 | 790 | def call_manager(params): 791 | print('%s: Starting Per Read Methylation Detection.' %str(datetime.datetime.now()), flush=True) 792 | if params['dev']!='cpu': 793 | torch.multiprocessing.set_start_method('spawn') 794 | 795 | torch.set_num_threads(1) 796 | 797 | pmanager = mp.Manager() 798 | 799 | bam=params['bam'] 800 | bam_file=pysam.AlignmentFile(bam,'rb',check_sq=False) 801 | header_dict=bam_file.header.to_dict() 802 | 803 | print('%s: Getting motif positions from the reference.' %str(datetime.datetime.now()), flush=True) 804 | 805 | ref_seq_dict={} 806 | ref_pos_dict={} 807 | 808 | mod_positions_list=get_pos(params['mod_positions']) if params['mod_positions'] else None 809 | position_based=True if params['mod_positions'] else False 810 | 811 | if position_based: 812 | params['chrom_list']=[x for x in params['chrom_list'] if x in mod_positions_list.keys()] 813 | 814 | _=get_ref_to_num('ACGT') 815 | 816 | if params['ref'] and len(params['chrom_list'])>0: 817 | with mp.Pool(processes=params['threads']) as pool: 818 | res=pool.map(get_ref_info, zip(repeat(params), params['chrom_list'])) 819 | for r in res: 820 | chrom, seq_array, fwd_pos_array, rev_pos_array=r 821 | ref_seq_dict[chrom]=seq_array 822 | 823 | if position_based: 824 | ref_pos_dict[chrom]=(np.array(sorted(list(set(fwd_pos_array)&set(mod_positions_list[chrom][0])))), 825 | np.array(sorted(list(set(rev_pos_array)&set(mod_positions_list[chrom][1]))))) 826 | 827 | else: 828 | ref_pos_dict[chrom]=(fwd_pos_array, rev_pos_array) 829 | 830 | 831 | params['position_based']=True if position_based or params['reference_motif_only'] else False 832 | 833 | res=None 834 | 835 | print('%s: Finished getting motif positions from the reference.' 
%str(datetime.datetime.now()), flush=True) 836 | 837 | signal_Q = pmanager.Queue() 838 | output_Q = pmanager.Queue() 839 | methylation_event=pmanager.Event() 840 | input_event=pmanager.Event() 841 | 842 | handlers = [] 843 | 844 | input_process = mp.Process(target=get_input, args=(params, signal_Q, output_Q, input_event)) 845 | input_process.start() 846 | 847 | if params['motif_seq']=='CG': 848 | output_process=mp.Process(target=get_cpg_output, args=(params, output_Q, methylation_event, header_dict, ref_pos_dict)); 849 | else: 850 | output_process=mp.Process(target=get_output, args=(params, output_Q, methylation_event, header_dict, ref_pos_dict)); 851 | output_process.start(); 852 | 853 | for hid in range(max(1,params['threads']-1)): 854 | p = mp.Process(target=process, args=(params, ref_pos_dict, signal_Q, output_Q, input_event, ref_seq_dict)); 855 | p.start(); 856 | handlers.append(p); 857 | 858 | input_process.join() 859 | print('%s: Reading inputs complete.' %str(datetime.datetime.now()), flush=True) 860 | 861 | for job in handlers: 862 | job.join() 863 | 864 | methylation_event.set() 865 | 866 | print('%s: Model predictions complete. Wrapping up output.' %str(datetime.datetime.now()), flush=True) 867 | 868 | output_process.join() 869 | 870 | return 871 | -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | import torch, math 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import TransformerEncoder, TransformerEncoderLayer 5 | from torch import Tensor 6 | 7 | class OneHotEncode(nn.Module): 8 | def __init__(self, num_classes: int): 9 | super().__init__() 10 | self.num_classes=num_classes 11 | def forward(self, x: Tensor) -> Tensor: 12 | return F.one_hot(x, self.num_classes) 13 | 14 | class RefandReadEmbed(nn.Module): 15 | def __init__(self, embedding_dim, embedding_type): 16 | super().__init__() 17 | 18 | self.embedding_depth=0 19 | 20 | if embedding_type=='one_hot': 21 | self.read_emb=OneHotEncode(4) 22 | self.embedding_depth+=4 23 | 24 | self.ref_emb=OneHotEncode(5) 25 | self.embedding_depth+=5 26 | 27 | elif embedding_type=='learnable': 28 | self.read_emb=nn.Embedding(4, embedding_dim) 29 | self.embedding_depth+=embedding_dim 30 | 31 | self.ref_emb=nn.Embedding(5, embedding_dim) 32 | self.embedding_depth+=embedding_dim 33 | 34 | def forward(self, batch_base_seq, batch_ref_seq): 35 | batch_base_seq_emb=self.read_emb(batch_base_seq) 36 | batch_ref_seq_emb=self.ref_emb(batch_ref_seq) 37 | 38 | return torch.cat((batch_base_seq_emb, batch_ref_seq_emb), 2) 39 | 40 | class ReadEmbed(nn.Module): 41 | def __init__(self, embedding_dim, embedding_type): 42 | super().__init__() 43 | 44 | self.embedding_depth=0 45 | 46 | if embedding_type=='one_hot': 47 | self.read_emb=OneHotEncode(4) 48 | self.embedding_depth+=4 49 | 50 | elif embedding_type=='learnable': 51 | self.read_emb=nn.Embedding(4, embedding_dim) 52 | self.embedding_depth+=embedding_dim 53 | 54 | 55 | def forward(self, batch_base_seq, batch_ref_seq): 56 | batch_base_seq_emb=self.read_emb(batch_base_seq) 57 | 58 | return batch_base_seq_emb 59 | 60 | class SeqEmbed(nn.Module): 61 | def __init__(self, embedding_dim, embedding_type, include_ref): 62 | super().__init__() 63 | 64 | self.embedding_depth=0 65 | 66 | if include_ref: 67 | self.seq_emb=RefandReadEmbed(embedding_dim, embedding_type) 68 | 69 | else: 70 | self.seq_emb=ReadEmbed(embedding_dim, embedding_type) 71 | 72 | 
self.embedding_depth=self.seq_emb.embedding_depth 73 | 74 | def forward(self, batch_base_seq, batch_ref_seq): 75 | return self.seq_emb(batch_base_seq, batch_ref_seq) 76 | 77 | class PositionalEncoding(nn.Module): 78 | def __init__(self, pe_dim: int, max_len: int): 79 | super().__init__() 80 | position = torch.arange(max_len).unsqueeze(1) 81 | div_term = torch.exp(torch.arange(0, pe_dim, 2) * (-math.log(pe_dim) / (pe_dim))) 82 | pe = torch.zeros(1, max_len, pe_dim) 83 | pe[0,:, 0::2] = torch.sin(position * div_term) 84 | pe[0,:, 1::2] = torch.cos(position * div_term) 85 | self.register_buffer('pe', pe) 86 | 87 | def forward(self, x: Tensor) -> Tensor: 88 | x_pos=torch.Tensor.repeat(self.pe,(x.size(0),1,1)) 89 | x = torch.cat((x, x_pos),2) 90 | return x 91 | 92 | class PositionalEmbedding(nn.Module): 93 | def __init__(self, pe_dim: int, max_len: int): 94 | super().__init__() 95 | pos=torch.arange(max_len) 96 | self.register_buffer('pos', pos) 97 | self.pe=nn.Embedding(max_len, pe_dim) 98 | 99 | def forward(self, x: Tensor) -> Tensor: 100 | x_pos=self.pe(self.pos) 101 | x_pos=torch.Tensor.repeat(x_pos,(x.size(0),1,1)) 102 | x = torch.cat((x, x_pos),2) 103 | return x 104 | 105 | class PositionalParameter(nn.Module): 106 | def __init__(self, pe_dim: int, max_len: int): 107 | super().__init__() 108 | 109 | self.pe=torch.nn.Parameter(torch.randn(max_len, pe_dim)) 110 | 111 | def forward(self, x: Tensor) -> Tensor: 112 | x_pos=torch.Tensor.repeat(self.pe,(x.size(0),1,1)) 113 | x = torch.cat((x, x_pos),2) 114 | return x 115 | 116 | class ClassifierMiddle(nn.Module): 117 | def __init__(self, in_dim: int, num_fc: int, model_len: int): 118 | super().__init__() 119 | self.mid = model_len//2 120 | self.fc = nn.Linear(in_dim, num_fc) 121 | self.out = nn.Linear(num_fc,1) 122 | 123 | def forward(self, x): 124 | x = F.relu(self.fc(x[:,self.mid, :])) 125 | x=self.out(x) 126 | return x 127 | 128 | class ClassifierAll(nn.Module): 129 | def __init__(self, in_dim: int, num_fc: int): 130 | super().__init__() 131 | self.fc = nn.Linear(in_dim, num_fc) 132 | self.out = nn.Linear(num_fc,1) 133 | 134 | def forward(self, x): 135 | x = torch.flatten(x, 1) 136 | x = F.relu(self.fc(x)) 137 | x=self.out(x) 138 | return x 139 | 140 | class BiLSTM(nn.Module): 141 | def __init__(self, model_dims, num_layers, dim_feedforward, num_fc, embedding_dim, embedding_type, include_ref, fc_type): 142 | super(BiLSTM, self).__init__() 143 | 144 | self.emb=SeqEmbed(embedding_dim, embedding_type, include_ref) 145 | self.model_len=model_dims[0] 146 | self.model_depth=model_dims[1]+self.emb.embedding_depth 147 | 148 | self.bilstm = nn.LSTM(input_size=self.model_depth, hidden_size=dim_feedforward, num_layers=num_layers, bidirectional=True, batch_first = True) 149 | 150 | if fc_type=='middle': 151 | self.classifier=ClassifierMiddle(in_dim=dim_feedforward*2, num_fc=num_fc, model_len=self.model_len) 152 | 153 | else: 154 | self.classifier=ClassifierAll(in_dim=self.model_len*dim_feedforward*2, num_fc=num_fc) 155 | 156 | def forward(self, batch_x, batch_base_seq, batch_ref_seq): 157 | seq_emb=self.emb(batch_base_seq, batch_ref_seq) 158 | x=torch.cat((batch_x, seq_emb), 2) 159 | x, _=self.bilstm(x) 160 | x = self.classifier(x) 161 | x=torch.nn.functional.sigmoid(x) 162 | return x 163 | 164 | class TransformerModel(nn.Module): 165 | def __init__(self, model_dims, num_layers, dim_feedforward, num_fc, embedding_dim, embedding_type, include_ref, pe_dim, nhead, pe_type, fc_type): 166 | super(TransformerModel, self).__init__() 167 | 168 | 
self.emb=SeqEmbed(embedding_dim, embedding_type, include_ref) 169 | self.model_len=model_dims[0] 170 | 171 | if pe_type=='fixed': 172 | self.pe_block=PositionalEncoding(pe_dim=pe_dim, max_len=self.model_len) 173 | 174 | elif pe_type=='embedding': 175 | self.pe_block=PositionalEmbedding(pe_dim=pe_dim, max_len=self.model_len) 176 | 177 | elif pe_type=='parameter': 178 | self.pe_block=PositionalParameter(pe_dim=pe_dim, max_len=self.model_len) 179 | 180 | self.model_depth=model_dims[1]+self.emb.embedding_depth+pe_dim 181 | self.pad_length=math.ceil(self.model_depth/nhead)*nhead-self.model_depth 182 | pad=torch.zeros(1,self.model_len, self.pad_length) 183 | self.register_buffer('pad', pad) 184 | self.model_depth+=self.pad_length 185 | 186 | encoder_layer = nn.TransformerEncoderLayer(d_model=self.model_depth, nhead=nhead, dim_feedforward=dim_feedforward, dropout=0, batch_first = True) 187 | self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) 188 | 189 | if fc_type=='middle': 190 | self.classifier=ClassifierMiddle(in_dim=self.model_depth, num_fc=num_fc, model_len=self.model_len) 191 | 192 | else: 193 | self.classifier=ClassifierAll(in_dim=self.model_len*self.model_depth, num_fc=num_fc) 194 | 195 | def forward(self, batch_x, batch_base_seq, batch_ref_seq): 196 | seq_emb=self.emb(batch_base_seq, batch_ref_seq) 197 | x=torch.cat((batch_x, seq_emb), 2) 198 | x=self.pe_block(x) 199 | x_pad=torch.Tensor.repeat(self.pad,(x.size(0),1,1)) 200 | x = torch.cat((x, x_pad),2) 201 | 202 | x=self.transformer_encoder(x) 203 | x = self.classifier(x) 204 | x=torch.nn.functional.sigmoid(x) 205 | return x -------------------------------------------------------------------------------- /src/models/bilstm.cfg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WGLab/DeepMod2/d674273150c03b57dc6cd0f38cf45966595aecf7/src/models/bilstm.cfg -------------------------------------------------------------------------------- /src/models/bilstm/R10.4.1_4kHz_v3.5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WGLab/DeepMod2/d674273150c03b57dc6cd0f38cf45966595aecf7/src/models/bilstm/R10.4.1_4kHz_v3.5 -------------------------------------------------------------------------------- /src/models/bilstm/R10.4.1_4kHz_v4.1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WGLab/DeepMod2/d674273150c03b57dc6cd0f38cf45966595aecf7/src/models/bilstm/R10.4.1_4kHz_v4.1 -------------------------------------------------------------------------------- /src/models/bilstm/R10.4.1_5kHz_v4.3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WGLab/DeepMod2/d674273150c03b57dc6cd0f38cf45966595aecf7/src/models/bilstm/R10.4.1_5kHz_v4.3 -------------------------------------------------------------------------------- /src/models/bilstm/R10.4.1_5kHz_v5.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WGLab/DeepMod2/d674273150c03b57dc6cd0f38cf45966595aecf7/src/models/bilstm/R10.4.1_5kHz_v5.0 -------------------------------------------------------------------------------- /src/models/bilstm/R9.4.1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WGLab/DeepMod2/d674273150c03b57dc6cd0f38cf45966595aecf7/src/models/bilstm/R9.4.1 
-------------------------------------------------------------------------------- /src/models/transformer.cfg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WGLab/DeepMod2/d674273150c03b57dc6cd0f38cf45966595aecf7/src/models/transformer.cfg -------------------------------------------------------------------------------- /src/models/transformer/R10.4.1_4kHz_v3.5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WGLab/DeepMod2/d674273150c03b57dc6cd0f38cf45966595aecf7/src/models/transformer/R10.4.1_4kHz_v3.5 -------------------------------------------------------------------------------- /src/models/transformer/R10.4.1_4kHz_v4.1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WGLab/DeepMod2/d674273150c03b57dc6cd0f38cf45966595aecf7/src/models/transformer/R10.4.1_4kHz_v4.1 -------------------------------------------------------------------------------- /src/models/transformer/R10.4.1_5kHz_v4.3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WGLab/DeepMod2/d674273150c03b57dc6cd0f38cf45966595aecf7/src/models/transformer/R10.4.1_5kHz_v4.3 -------------------------------------------------------------------------------- /src/models/transformer/R10.4.1_5kHz_v5.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WGLab/DeepMod2/d674273150c03b57dc6cd0f38cf45966595aecf7/src/models/transformer/R10.4.1_5kHz_v5.0 -------------------------------------------------------------------------------- /src/models/transformer/R9.4.1: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WGLab/DeepMod2/d674273150c03b57dc6cd0f38cf45966595aecf7/src/models/transformer/R9.4.1 -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | from subprocess import PIPE, Popen 2 | import os, shutil, pysam, sys, datetime, re, pickle 3 | import numpy as np 4 | from numba import jit 5 | import torch 6 | from itertools import repeat 7 | import pysam 8 | from .models import * 9 | import torch.nn.utils.prune as prune 10 | from tqdm import tqdm 11 | 12 | model_dict={ 13 | 14 | 'bilstm_r9.4.1' : { 'path' : 'models/bilstm/R9.4.1', 15 | 'help':'BiLSTM model trained on chr2-21 of HG002, HG003 and HG004 R9.4.1 flowcells.', 16 | 'model_config_path':'models/bilstm.cfg'}, 17 | 18 | 'bilstm_r10.4.1_4khz_v3.5': { 'path' : 'models/bilstm/R10.4.1_4kHz_v3.5', 19 | 'help': 'BiLSTM model trained on chr2-21 of HG002, HG003 and HG004 R10.4.1 flowcells with 4kHz sampling, with basecalling performed by v3.5 Guppy/Dorado basecaller model.', 20 | 'model_config_path':'models/bilstm.cfg'}, 21 | 22 | 'bilstm_r10.4.1_4khz_v4.1': { 'path' : 'models/bilstm/R10.4.1_4kHz_v4.1', 23 | 'help': 'BiLSTM model trained on chr2-21 of HG002, HG003 and HG004 R10.4.1 flowcells with 4kHz sampling, with basecalling performed by v4.1 Guppy/Dorado basecaller model.', 24 | 'model_config_path':'models/bilstm.cfg'}, 25 | 26 | 'bilstm_r10.4.1_5khz_v4.3': { 'path' : 'models/bilstm/R10.4.1_5kHz_v4.3', 27 | 'help': 'BiLSTM model trained on chr2-21 of HG002, HG003 and HG004 R10.4.1 flowcells with 5kHz sampling, with basecalling performed by v4.3 Guppy/Dorado basecaller model.', 28 
| 'model_config_path':'models/bilstm.cfg'}, 29 | 30 | 'bilstm_r10.4.1_5khz_v5.0': { 'path' : 'models/bilstm/R10.4.1_5kHz_v5.0', 31 | 'help': 'BiLSTM model trained on chr2-21 of HG002, HG003 and HG004 R10.4.1 flowcells with 5kHz sampling, with basecalling performed by v5.0 Guppy/Dorado basecaller model.', 32 | 'model_config_path':'models/bilstm.cfg'}, 33 | 34 | 35 | 'transformer_r9.4.1' : { 'path' : 'models/transformer/R9.4.1', 36 | 'help':'Transformer model trained on chr2-21 of HG002, HG003 and HG004 R9.4.1 flowcells.', 37 | 'model_config_path':'models/transformer.cfg'}, 38 | 39 | 'transformer_r10.4.1_4khz_v3.5': { 'path' : 'models/transformer/R10.4.1_4kHz_v3.5', 40 | 'help': 'Transformer model trained on chr2-21 of HG002, HG003 and HG004 R10.4.1 flowcells with 4kHz sampling, with basecalling performed by v3.5 Guppy/Dorado basecaller model.', 41 | 'model_config_path':'models/transformer.cfg'}, 42 | 43 | 'transformer_r10.4.1_4khz_v4.1': { 'path' : 'models/transformer/R10.4.1_4kHz_v4.1', 44 | 'help': 'Transformer model trained on chr2-21 of HG002, HG003 and HG004 R10.4.1 flowcells with 4kHz sampling, with basecalling performed by v4.1 Guppy/Dorado basecaller model.', 45 | 'model_config_path':'models/transformer.cfg'}, 46 | 47 | 'transformer_r10.4.1_5khz_v4.3': { 'path' : 'models/transformer/R10.4.1_5kHz_v4.3', 48 | 'help': 'Transformer model trained on chr2-21 of HG002, HG003 and HG004 R10.4.1 flowcells with 5kHz sampling, with basecalling performed by v4.3 Guppy/Dorado basecaller model.', 49 | 'model_config_path':'models/transformer.cfg'}, 50 | 51 | 'transformer_r10.4.1_5khz_v5.0': { 'path' : 'models/transformer/R10.4.1_5kHz_v5.0', 52 | 'help': 'Transformer model trained on chr2-21 of HG002, HG003 and HG004 R10.4.1 flowcells with 5kHz sampling, with basecalling performed by v5.0 Guppy/Dorado basecaller model.', 53 | 'model_config_path':'models/transformer.cfg'}, 54 | } 55 | 56 | comp_base_map={'A':'T','T':'A','C':'G','G':'C','[':']', ']':'['} 57 | 58 | def revcomp(s): 59 | return ''.join(comp_base_map[x] for x in s[::-1]) 60 | 61 | def get_model_help(): 62 | for n,model in enumerate(model_dict): 63 | print('-'*30) 64 | print('%d) Model Name: %s' %(n+1, model)) 65 | print('Details: %s\n' %model_dict[model]['help']) 66 | 67 | def get_model(params): 68 | model_name=params['model'] 69 | model_config_path=None 70 | model_path=None 71 | 72 | if model_name in model_dict: 73 | dirname = os.path.dirname(__file__) 74 | model_info=model_dict[model_name] 75 | model_config_path = os.path.join(dirname, model_info['model_config_path']) 76 | model_path = os.path.join(dirname, model_info['path']) 77 | 78 | else: 79 | try: 80 | model_config_path = model_name.split(',')[0] 81 | model_path = model_name.split(',')[1] 82 | 83 | except IndexError: 84 | print('Incorrect model specified') 85 | sys.exit(2) 86 | 87 | with open(model_config_path, 'rb') as handle: 88 | model_config = pickle.load(handle) 89 | 90 | if model_config['model_type']=='bilstm': 91 | model = BiLSTM(model_dims=model_config['model_dims'], num_layers=model_config['num_layers'], \ 92 | dim_feedforward=model_config['dim_feedforward'], \ 93 | num_fc=model_config['num_fc'], embedding_dim=model_config['embedding_dim'], \ 94 | embedding_type=model_config['embedding_type'], include_ref=model_config['include_ref'], \ 95 | fc_type=model_config['fc_type']); 96 | 97 | checkpoint = torch.load(model_path, map_location ='cpu') 98 | model.load_state_dict(checkpoint['model_state_dict']) 99 | 100 | if not params['disable_pruning']: 101 |
module=model.classifier.fc 102 | prune.l1_unstructured(module, name="weight", amount=0.95) 103 | prune.remove(module, 'weight') 104 | 105 | return model, model_config 106 | 107 | elif model_config['model_type']=='transformer': 108 | model = TransformerModel(model_dims=model_config['model_dims'], num_layers=model_config['num_layers'], \ 109 | dim_feedforward=model_config['dim_feedforward'], \ 110 | num_fc=model_config['num_fc'], embedding_dim=model_config['embedding_dim'], \ 111 | embedding_type=model_config['embedding_type'], include_ref=model_config['include_ref'],\ 112 | pe_dim=model_config['pe_dim'], nhead=model_config['nhead'], \ 113 | pe_type=model_config['pe_type'], fc_type=model_config['fc_type']); 114 | 115 | checkpoint = torch.load(model_path, map_location ='cpu') 116 | model.load_state_dict(checkpoint['model_state_dict']) 117 | 118 | if not params['disable_pruning']: 119 | module=model.classifier.fc 120 | prune.l1_unstructured(module, name="weight", amount=0.5) 121 | prune.remove(module, 'weight') 122 | for l in model.transformer_encoder.layers: 123 | module=l.linear1 124 | prune.l1_unstructured(module, name="weight", amount=0.25) 125 | prune.remove(module, 'weight') 126 | module=l.linear2 127 | prune.l1_unstructured(module, name="weight", amount=0.25) 128 | prune.remove(module, 'weight') 129 | module=l.self_attn.out_proj 130 | prune.l1_unstructured(module, name="weight", amount=0.25) 131 | prune.remove(module, 'weight') 132 | 133 | return model, model_config 134 | 135 | else: 136 | print('Model: %s not found.' %params['model'], flush=True) 137 | sys.exit(2) 138 | 139 | def generate_batches(features, base_seq, window, ref_seq=None, batch_size=512): 140 | if len(ref_seq)==0: 141 | ref_seq=(4+torch.zeros(features.shape[0], 2*window+1)).type(torch.LongTensor) 142 | else: 143 | ref_seq=torch.Tensor(ref_seq).type(torch.LongTensor) 144 | 145 | features=torch.Tensor(features) 146 | base_seq=torch.Tensor(base_seq).type(torch.LongTensor) 147 | 148 | for local_index in range(0, features.shape[0], batch_size): 149 | batch_x=features[local_index:(local_index + batch_size)] 150 | batch_base_seq=base_seq[local_index:(local_index + batch_size)] 151 | batch_ref_seq=ref_seq[local_index:(local_index + batch_size)] 152 | 153 | yield batch_x, batch_base_seq, batch_ref_seq 154 | 155 | @jit(nopython=True) 156 | def get_aligned_pairs(cigar_tuples, ref_start): 157 | alen=np.sum(cigar_tuples[:,0]) 158 | pairs=np.zeros((alen,2)).astype(np.int32) 159 | 160 | i=0 161 | ref_cord=ref_start-1 162 | read_cord=-1 163 | pair_cord=0 164 | for i in range(len(cigar_tuples)): 165 | len_op, op= cigar_tuples[i,0], cigar_tuples[i,1] 166 | if op==0: 167 | for k in range(len_op): 168 | ref_cord+=1 169 | read_cord+=1 170 | 171 | pairs[pair_cord,0]=read_cord 172 | pairs[pair_cord,1]=ref_cord 173 | pair_cord+=1 174 | 175 | elif op==2: 176 | for k in range(len_op): 177 | read_cord+=1 178 | pairs[pair_cord,0]=read_cord 179 | pairs[pair_cord,1]=-1 180 | pair_cord+=1 181 | 182 | elif op==1: 183 | for k in range(len_op): 184 | ref_cord+=1 185 | pairs[pair_cord,0]=-1 186 | pairs[pair_cord,1]=ref_cord 187 | pair_cord+=1 188 | return pairs 189 | 190 | @jit(nopython=True) 191 | def get_ref_to_num(x): 192 | b=np.full((len(x)+1,2),fill_value=0,dtype=np.int8) 193 | 194 | for i,l in enumerate(x): 195 | if l=='A': 196 | b[i,0]=0 197 | b[i,1]=3 198 | 199 | elif l=='T': 200 | b[i,0]=3 201 | b[i,1]=0 202 | 203 | elif l=='C': 204 | b[i,0]=1 205 | b[i,1]=2 206 | 207 | elif l=='G': 208 | b[i,0]=2 209 | b[i,1]=1 210 | 211 | else: 212 | b[i,0]=4 
213 | b[i,1]=4 214 | 215 | b[-1,0]=4 216 | b[-1,1]=4 217 | 218 | return b 219 | 220 | def get_pos(path): 221 | labelled_pos_list={} 222 | strand_map={'+':0, '-':1} 223 | 224 | with open(path) as file: 225 | for line in file: 226 | line=line.rstrip('\n').split('\t') 227 | if line[0] not in labelled_pos_list: 228 | labelled_pos_list[line[0]]={0:[], 1:[]} 229 | 230 | labelled_pos_list[line[0]][strand_map[line[2]]].append(int(line[1])) 231 | 232 | return labelled_pos_list 233 | 234 | def motif_check(motif): 235 | nt_dict={'R': 'GA', 236 | 'Y': 'CT', 237 | 'K': 'GT', 238 | 'M': 'AC', 239 | 'S': 'GC', 240 | 'W': 'AT', 241 | 'B': 'GTC', 242 | 'D': 'GAT', 243 | 'H': 'ACT', 244 | 'V': 'GCA', 245 | 'N': 'AGCT'} 246 | 247 | valid_alphabet=set(nt_dict.keys()).union({'A', 'C', 'G', 'T'}) 248 | 249 | motif_seq, exp_motif_seq, final_motif_ind, valid = None, None, None, False 250 | 251 | if len(motif)<2: 252 | print('--motif not specified correctly. You need to specify a motif and at least one index',flush=True) 253 | return motif_seq, exp_motif_seq, final_motif_ind, valid 254 | 255 | elif len(set(motif[0])-valid_alphabet)>0: 256 | print('--motif not specified correctly. Motif should only consist of the following extended nucleotide letters: {}'.format(','.join(valid_alphabet)),flush=True) 257 | return motif_seq, exp_motif_seq, final_motif_ind, valid 258 | 259 | elif all([a.isnumeric() for a in motif[1:]])==False: 260 | print('--motif not specified correctly. Motif indices should be integers separated by whitespace and should come after the motif sequence.',flush=True) 261 | return motif_seq, exp_motif_seq, final_motif_ind, valid 262 | 263 | else: 264 | motif_seq=motif[0] 265 | motif_ind=[int(x) for x in motif[1:]] 266 | 267 | if len(set(motif_seq[x] for x in motif_ind))!=1 or len(set(motif_seq[x] for x in motif_ind)-set('ACGT'))>0: 268 | print('Base of interest should be the same for all indices and must be one of A, C, G or T.', flush=True) 269 | return motif_seq, exp_motif_seq, final_motif_ind, valid 270 | 271 | else: 272 | exp_motif_seq=motif_seq 273 | for nt in nt_dict: 274 | if nt in exp_motif_seq: 275 | exp_motif_seq=exp_motif_seq.replace(nt, '[{}]'.format(nt_dict[nt])) 276 | return motif_seq, exp_motif_seq, motif_ind, True 277 | 278 | def get_ref_info(args): 279 | params, chrom=args 280 | motif_seq, motif_ind=params['motif_seq'], params['motif_ind'] 281 | exp_motif_seq=params['exp_motif_seq'] 282 | ref_fasta=pysam.FastaFile(params['ref']) 283 | seq=ref_fasta.fetch(chrom).upper() 284 | seq_array=get_ref_to_num(seq) 285 | 286 | 287 | fwd_motif_anchor=np.array([m.start(0) for m in re.finditer(r'(?={})'.format(exp_motif_seq), seq)]) 288 | rev_motif_anchor=np.array([m.start(0) for m in re.finditer(r'(?={})'.format(revcomp(exp_motif_seq)), seq)]) 289 | 290 | fwd_pos_array=np.array(sorted(list(set.union(*[set(fwd_motif_anchor+i) for i in motif_ind])))).astype(int) 291 | rev_pos_array=np.array(sorted(list(set.union(*[set(rev_motif_anchor+len(motif_seq)-1-i) for i in motif_ind])))).astype(int) 292 | 293 | return chrom, seq_array, fwd_pos_array, rev_pos_array 294 | 295 | def get_stats_string_cpg(chrom, pos, is_ref_cpg, cpg): 296 | unphased_rev_unmod, unphased_rev_mod, unphased_fwd_unmod, unphased_fwd_mod=cpg[0:4] 297 | phase1_rev_unmod, phase1_rev_mod, phase1_fwd_unmod, phase1_fwd_mod=cpg[4:8] 298 | phase2_rev_unmod, phase2_rev_mod, phase2_fwd_unmod, phase2_fwd_mod=cpg[8:12] 299 | 300 | fwd_mod=unphased_fwd_mod+phase1_fwd_mod+phase2_fwd_mod 301 | fwd_unmod=unphased_fwd_unmod+phase1_fwd_unmod+phase2_fwd_unmod 302 |
fwd_total_stats=[fwd_mod+fwd_unmod,fwd_mod,fwd_unmod,fwd_mod/(fwd_mod+fwd_unmod) if fwd_mod+fwd_unmod>0 else 0] 303 | fwd_phase1_stats=[phase1_fwd_mod+phase1_fwd_unmod, phase1_fwd_mod, phase1_fwd_unmod, phase1_fwd_mod/(phase1_fwd_mod+phase1_fwd_unmod) if phase1_fwd_mod+phase1_fwd_unmod>0 else 0] 304 | fwd_phase2_stats=[phase2_fwd_mod+phase2_fwd_unmod, phase2_fwd_mod, phase2_fwd_unmod, phase2_fwd_mod/(phase2_fwd_mod+phase2_fwd_unmod) if phase2_fwd_mod+phase2_fwd_unmod>0 else 0] 305 | 306 | fwd_str='{}\t{}\t{}\t+\t{}\t'.format(chrom, pos, pos+1, is_ref_cpg)+'{}\t{}\t{}\t{:.4f}\t'.format(*fwd_total_stats) + '{}\t{}\t{}\t{:.4f}\t'.format(*fwd_phase1_stats) + '{}\t{}\t{}\t{:.4f}\n'.format(*fwd_phase2_stats) 307 | 308 | 309 | rev_mod=unphased_rev_mod+phase1_rev_mod+phase2_rev_mod 310 | rev_unmod=unphased_rev_unmod+phase1_rev_unmod+phase2_rev_unmod 311 | rev_total_stats=[rev_mod+rev_unmod,rev_mod,rev_unmod,rev_mod/(rev_mod+rev_unmod) if rev_mod+rev_unmod>0 else 0] 312 | rev_phase1_stats=[phase1_rev_mod+phase1_rev_unmod, phase1_rev_mod, phase1_rev_unmod, phase1_rev_mod/(phase1_rev_mod+phase1_rev_unmod) if phase1_rev_mod+phase1_rev_unmod>0 else 0] 313 | rev_phase2_stats=[phase2_rev_mod+phase2_rev_unmod, phase2_rev_mod, phase2_rev_unmod, phase2_rev_mod/(phase2_rev_mod+phase2_rev_unmod) if phase2_rev_mod+phase2_rev_unmod>0 else 0] 314 | 315 | rev_str='{}\t{}\t{}\t-\t{}\t'.format(chrom, pos+1, pos+2, is_ref_cpg)+'{}\t{}\t{}\t{:.4f}\t'.format(*rev_total_stats) + '{}\t{}\t{}\t{:.4f}\t'.format(*rev_phase1_stats) + '{}\t{}\t{}\t{:.4f}\n'.format(*rev_phase2_stats) 316 | 317 | 318 | agg_total_stats=[fwd_total_stats[0]+rev_total_stats[0], fwd_total_stats[1]+rev_total_stats[1], fwd_total_stats[2]+rev_total_stats[2], (fwd_total_stats[1]+rev_total_stats[1])/(fwd_total_stats[0]+rev_total_stats[0]) if fwd_total_stats[0]+rev_total_stats[0]>0 else 0] 319 | 320 | agg_phase1_stats=[fwd_phase1_stats[0]+rev_phase1_stats[0], fwd_phase1_stats[1]+rev_phase1_stats[1], fwd_phase1_stats[2]+rev_phase1_stats[2], (fwd_phase1_stats[1]+rev_phase1_stats[1])/(fwd_phase1_stats[0]+rev_phase1_stats[0]) if fwd_phase1_stats[0]+rev_phase1_stats[0]>0 else 0] 321 | 322 | agg_phase2_stats=[fwd_phase2_stats[0]+rev_phase2_stats[0], fwd_phase2_stats[1]+rev_phase2_stats[1], fwd_phase2_stats[2]+rev_phase2_stats[2], (fwd_phase2_stats[1]+rev_phase2_stats[1])/(fwd_phase2_stats[0]+rev_phase2_stats[0]) if fwd_phase2_stats[0]+rev_phase2_stats[0]>0 else 0] 323 | 324 | agg_str='{}\t{}\t{}\t{}\t'.format(chrom, pos, pos+2, is_ref_cpg)+'{}\t{}\t{}\t{:.4f}\t'.format(*agg_total_stats) + '{}\t{}\t{}\t{:.4f}\t'.format(*agg_phase1_stats) + '{}\t{}\t{}\t{:.4f}\n'.format(*agg_phase2_stats) 325 | 326 | return [(agg_total_stats[0], agg_str),(fwd_total_stats[0], fwd_str),(rev_total_stats[0], rev_str)] 327 | 328 | 329 | def get_stats_string(chrom, pos, strand, mod_call): 330 | unphased_unmod, unphased_mod=mod_call[0:2] 331 | phase1_unmod, phase1_mod=mod_call[2:4] 332 | phase2_unmod, phase2_mod=mod_call[4:6] 333 | 334 | mod=unphased_mod+phase1_mod+phase2_mod 335 | unmod=unphased_unmod+phase1_unmod+phase2_unmod 336 | 337 | total_stats=[mod+unmod, mod, unmod,mod/(mod+unmod) if mod+unmod>0 else 0] 338 | phase1_stats=[phase1_mod+phase1_unmod, phase1_mod, phase1_unmod, phase1_mod/(phase1_mod+phase1_unmod) if phase1_mod+phase1_unmod>0 else 0] 339 | phase2_stats=[phase2_mod+phase2_unmod, phase2_mod, phase2_unmod, phase2_mod/(phase2_mod+phase2_unmod) if phase2_mod+phase2_unmod>0 else 0] 340 | 341 | mod_str='{}\t{}\t{}\t{}\t'.format(chrom, pos, pos+1, 
strand)+'{}\t{}\t{}\t{:.4f}\t'.format(*total_stats) + '{}\t{}\t{}\t{:.4f}\t'.format(*phase1_stats) + '{}\t{}\t{}\t{:.4f}\n'.format(*phase2_stats) 342 | 343 | return total_stats[0], mod_str 344 | 345 | def get_cpg_output(params, input_list): 346 | qscore_cutoff=params['qscore_cutoff'] 347 | length_cutoff=params['length_cutoff'] 348 | 349 | mod_threshold=params['mod_t'] 350 | unmod_threshold=params['unmod_t'] 351 | 352 | cpg_ref_only=not params['include_non_cpg_ref'] 353 | 354 | 355 | 356 | total_files=len(input_list) 357 | print('%s: Reading %d files.' %(str(datetime.datetime.now()), total_files), flush=True) 358 | pbar = tqdm(total=total_files) 359 | 360 | per_site_pred={} 361 | 362 | for read_pred_file in input_list: 363 | with open(read_pred_file,'r') as read_file: 364 | read_file.readline() 365 | for line in read_file: 366 | read, chrom, pos, pos_after, read_pos, strand, score, mean_qscore, sequence_length, phase, is_ref_cpg = line.rstrip('\n').split('\t') 367 | 368 | if pos=='NA' or float(mean_qscore)unmod_threshold: 374 | continue 375 | else: 376 | mod=score>=mod_threshold 377 | 378 | pos=int(pos) 379 | phase=int(phase) 380 | is_forward=1 if strand=='+' else 0 381 | 382 | idx=4*phase+2*is_forward 383 | 384 | is_ref_cpg=True if is_ref_cpg =='True' else False 385 | zero_based_fwd_pos=pos if strand=='+' else pos-1 386 | 387 | if (chrom, zero_based_fwd_pos) not in per_site_pred: 388 | per_site_pred[(chrom, zero_based_fwd_pos)]=[0]*12+[is_ref_cpg] 389 | 390 | per_site_pred[(chrom, zero_based_fwd_pos)][idx+mod]+=1 391 | 392 | pbar.update(1) 393 | pbar.close() 394 | 395 | print('%s: Writing Per Site Methylation Detection.' %str(datetime.datetime.now()), flush=True) 396 | 397 | per_site_fields=['#chromosome', 'position_before', 'position','strand', 'ref_cpg', 398 | 'coverage','mod_coverage', 'unmod_coverage','mod_fraction', 399 | 'coverage_phase1','mod_coverage_phase1', 'unmod_coverage_phase1','mod_fraction_phase1', 400 | 'coverage_phase2','mod_coverage_phase2', 'unmod_coverage_phase2','mod_fraction_phase2'] 401 | per_site_header='\t'.join(per_site_fields)+'\n' 402 | per_site_fields.remove('strand') 403 | agg_per_site_header='\t'.join(per_site_fields)+'\n' 404 | 405 | per_site_file_path=os.path.join(params['output'],'%s.per_site' %params['prefix']) 406 | agg_per_site_file_path=os.path.join(params['output'],'%s.per_site.aggregated' %params['prefix']) 407 | 408 | with open(per_site_file_path, 'w') as per_site_file, open(agg_per_site_file_path,'w') as agg_per_site_file: 409 | per_site_file.write(per_site_header) 410 | agg_per_site_file.write(agg_per_site_header) 411 | 412 | for x in sorted(per_site_pred.keys()): 413 | chrom, pos=x 414 | cpg=per_site_pred[x] 415 | is_ref_cpg=cpg[12] 416 | 417 | if cpg_ref_only and is_ref_cpg==False: 418 | continue 419 | #fwd_stats=[self.chrom, self.position, self.position+1, '+', self.is_ref_cpg]+self.get_all_phases().forward.stats() + self.phase_1.forward.stats() + self.phase_2.forward.stats() 420 | 421 | agg_stats, fwd_stats, rev_stats=get_stats_string_cpg(chrom, pos, is_ref_cpg, cpg) 422 | if agg_stats[0]>0: 423 | agg_per_site_file.write(agg_stats[1]) 424 | 425 | if fwd_stats[0]>0: 426 | per_site_file.write(fwd_stats[1]) 427 | 428 | if rev_stats[0]>0: 429 | per_site_file.write(rev_stats[1]) 430 | 431 | print('%s: Finished Writing Per Site Methylation Output.' 
%str(datetime.datetime.now()), flush=True) 432 | print('%s: Per Site Prediction file: %s' %(str(datetime.datetime.now()), per_site_file_path), flush=True) 433 | print('%s: Aggregated Per Site Prediction file: %s' %(str(datetime.datetime.now()), agg_per_site_file_path), flush=True) 434 | 435 | def get_output(params, input_list): 436 | qscore_cutoff=params['qscore_cutoff'] 437 | length_cutoff=params['length_cutoff'] 438 | 439 | mod_threshold=params['mod_t'] 440 | unmod_threshold=params['unmod_t'] 441 | 442 | cpg_ref_only=not params['include_non_cpg_ref'] 443 | 444 | 445 | 446 | total_files=len(input_list) 447 | print('%s: Reading %d files.' %(str(datetime.datetime.now()), total_files), flush=True) 448 | pbar = tqdm(total=total_files) 449 | 450 | per_site_pred={} 451 | 452 | for read_pred_file in input_list: 453 | with open(read_pred_file,'r') as read_file: 454 | read_file.readline() 455 | for line in read_file: 456 | read, chrom, pos, pos_after, read_pos, strand, score, mean_qscore, sequence_length, phase = line.rstrip('\n').split('\t') 457 | 458 | if pos=='NA' or float(mean_qscore)unmod_threshold: 464 | continue 465 | else: 466 | mod=score>=mod_threshold 467 | 468 | pos=int(pos) 469 | phase=int(phase) 470 | 471 | if (chrom, pos,strand) not in per_site_pred: 472 | per_site_pred[(chrom, pos,strand)]=[0]*6 473 | 474 | per_site_pred[(chrom, pos,strand)][2*phase+mod]+=1 475 | 476 | pbar.update(1) 477 | pbar.close() 478 | 479 | print('%s: Writing Per Site Methylation Detection.' %str(datetime.datetime.now()), flush=True) 480 | 481 | per_site_fields=['#chromosome', 'position_before', 'position','strand', 482 | 'coverage','mod_coverage', 'unmod_coverage','mod_fraction', 483 | 'coverage_phase1','mod_coverage_phase1', 'unmod_coverage_phase1','mod_fraction_phase1', 484 | 'coverage_phase2','mod_coverage_phase2', 'unmod_coverage_phase2','mod_fraction_phase2'] 485 | per_site_header='\t'.join(per_site_fields)+'\n' 486 | 487 | per_site_file_path=os.path.join(params['output'],'%s.per_site' %params['prefix']) 488 | 489 | with open(per_site_file_path, 'w') as per_site_file: 490 | per_site_file.write(per_site_header) 491 | 492 | for x in sorted(per_site_pred.keys()): 493 | chrom, pos, strand=x 494 | mod_call=per_site_pred[x] 495 | 496 | stats=get_stats_string(chrom, pos, strand, mod_call) 497 | if stats[0]>0: 498 | per_site_file.write(stats[1]) 499 | 500 | print('%s: Finished Writing Per Site Methylation Output.' %str(datetime.datetime.now()), flush=True) 501 | print('%s: Per Site Prediction file: %s' %(str(datetime.datetime.now()), per_site_file_path), flush=True) 502 | 503 | def get_per_site(params, input_list): 504 | print('%s: Starting Per Site Methylation Detection.' 
%str(datetime.datetime.now()), flush=True) 505 | 506 | if params['cpg_output']: 507 | get_cpg_output(params, input_list) 508 | else: 509 | get_output(params, input_list) 510 | 511 | -------------------------------------------------------------------------------- /train/generate_features.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict, ChainMap 2 | import time, itertools, h5py, pysam 3 | import datetime, os, shutil, argparse, sys, re, array 4 | import os 5 | from itertools import repeat 6 | import multiprocessing as mp 7 | import numpy as np 8 | from pathlib import Path 9 | from ont_fast5_api.fast5_interface import get_fast5_file 10 | from numba import jit 11 | import queue, gzip 12 | import pod5 as p5 13 | import utils 14 | 15 | base_to_num_map={'A':0, 'C':1, 'G':2, 'T':3, 'U':3,'N':4} 16 | 17 | num_to_base_map={0:'A', 1:'C', 2:'G', 3:'T', 4:'N'} 18 | 19 | comp_base_map={'A':'T','T':'A','C':'G','G':'C','[':']', ']':'['} 20 | 21 | def revcomp(s): 22 | return ''.join(comp_base_map[x] for x in s[::-1]) 23 | 24 | def motif_check(motif): 25 | nt_dict={'R': 'GA', 26 | 'Y': 'CT', 27 | 'K': 'GT', 28 | 'M': 'AC', 29 | 'S': 'GC', 30 | 'W': 'AT', 31 | 'B': 'GTC', 32 | 'D': 'GAT', 33 | 'H': 'ACT', 34 | 'V': 'GCA', 35 | 'N': 'AGCT'} 36 | 37 | valid_alphabet=set(nt_dict.keys()).union({'A', 'C', 'G', 'T'}) 38 | 39 | exp_motif_seq, final_motif_ind, valid = None, None, False 40 | 41 | if len(motif)<2: 42 | print('--motif not specified correctly. You need to specify a motif and at least one index',flush=True) 43 | return motif_seq, exp_motif_seq, final_motif_ind, valid 44 | 45 | elif len(set(motif[0])-valid_alphabet)>0: 46 | print('--motif not specified correctly. Motif should only consist of the following extended nucleotide letters: {}'.format(','.join(valid_alphabet)),flush=True) 47 | return motif_seq, exp_motif_seq, final_motif_ind, valid 48 | 49 | elif all([a.isnumeric() for a in motif[1:]])==False: 50 | print('--motif not specified correctly. 
Motif indices should be integers separated by whitespace and shoud come after the motif sequence.',flush=True) 51 | return motif_seq, exp_motif_seq, final_motif_ind, valid 52 | 53 | else: 54 | motif_seq=motif[0] 55 | motif_ind=[int(x) for x in motif[1:]] 56 | 57 | if len(set(motif_seq[x] for x in motif_ind))!=1 or len(set(motif_seq[x] for x in motif_ind)-set('ACGT'))>0: 58 | print('Base of interest should be same for all indices and must be one of A, C, G or T.', flush=True) 59 | return motif_seq, exp_motif_seq, final_motif_ind, valid 60 | 61 | else: 62 | exp_motif_seq=motif_seq 63 | for nt in nt_dict: 64 | if nt in exp_motif_seq: 65 | exp_motif_seq=exp_motif_seq.replace(nt, '[{}]'.format(nt_dict[nt])) 66 | return motif_seq, exp_motif_seq, motif_ind, True 67 | 68 | def get_candidates(read_seq, align_data, aligned_pairs, ref_pos_dict, ): 69 | is_mapped, is_forward, ref_name, reference_start, reference_end, read_length=align_data 70 | 71 | ref_motif_pos=ref_pos_dict[ref_name][0] if is_forward else ref_pos_dict[ref_name][1] 72 | 73 | common_pos=ref_motif_pos[(ref_motif_pos>=reference_start)&(ref_motif_pos5]=5 187 | signal[signal<-5]=-5 188 | 189 | move_len=len(move_table) 190 | move_index=np.where(move_table)[0] 191 | rlen=len(move_index) 192 | 193 | data=np.zeros((rlen,9)) 194 | 195 | for i in range(len(move_index)-1): 196 | prev=move_index[i]*stride+start 197 | sig_end=move_index[i+1]*stride+start 198 | 199 | sig_len=sig_end-prev 200 | data[i, 8]=np.log10(sig_len) 201 | data[i, 4]=np.median(signal[prev:sig_end]) 202 | data[i, 5]=np.median(np.abs(signal[prev:sig_end]-data[i, 4])) 203 | data[i, 6]=np.mean(signal[prev:sig_end]) 204 | data[i, 7]=np.std(signal[prev:sig_end]) 205 | 206 | for j in range(4): 207 | tmp_cnt=0 208 | for t in range(j*sig_len//4,min(sig_len, (j+1)*sig_len//4)): 209 | data[i, j]+=signal[t+prev] 210 | tmp_cnt+=1 211 | data[i, j]=data[i, j]/tmp_cnt 212 | 213 | return data 214 | 215 | def get_pos(path): 216 | labelled_pos_list={} 217 | strand_map={'+':0, '-':1} 218 | 219 | with open(path) as file: 220 | for line in file: 221 | line=line.rstrip('\n').split('\t') 222 | if line[0] not in labelled_pos_list: 223 | labelled_pos_list[line[0]]={0:{}, 1:{}} 224 | 225 | labelled_pos_list[line[0]][strand_map[line[2]]][int(line[1])]=float(line[3]) 226 | 227 | return labelled_pos_list 228 | 229 | def write_to_npz(output_file_path, mat, base_qual, base_seq, ref_seq, label, ref_coordinates, read_name, ref_name, window, norm_type): 230 | np.savez(output_file_path, mat=mat, base_qual=base_qual, base_seq=base_seq, ref_seq=ref_seq, label=label, ref_coordinates=ref_coordinates, read_name=read_name, ref_name=ref_name, window=window, norm_type=norm_type, strides_per_base=1, model_depth=mat.shape[2], full_signal=False) 231 | 232 | def get_output(params, output_Q, process_event): 233 | output=params['output'] 234 | 235 | reads_per_chunk=params['reads_per_chunk'] 236 | 237 | chunk=1 238 | read_count=0 239 | 240 | output_file_path=os.path.join(output,'%s.features.%d.npz' %(params['prefix'], chunk)) 241 | 242 | mat, base_qual, base_seq, ref_seq, label=[], [], [], [], [] 243 | ref_coordinates, read_name, ref_name= [], [], [] 244 | 245 | while True: 246 | if process_event.is_set() and output_Q.empty(): 247 | break 248 | else: 249 | try: 250 | res = output_Q.get(block=False) 251 | #per_site_features, per_site_base_qual, per_site_base_seq, per_site_ref_seq, per_site_ref_coordinates, per_site_label, read_name_array, ref_name_array 252 | 253 | mat.append(res[0]) 254 | base_qual.append(res[1]) 255 | 
base_seq.append(res[2]) 256 | ref_seq.append(res[3]) 257 | ref_coordinates.append(res[4]) 258 | label.append(res[5]) 259 | read_name.append(res[6]) 260 | ref_name.append(res[7]) 261 | 262 | read_count+=1 263 | 264 | if read_count%reads_per_chunk==0 and len(mat)>0: 265 | mat=np.vstack(mat) 266 | base_qual=np.vstack(base_qual) 267 | base_seq=np.vstack(base_seq).astype(np.int8) 268 | ref_seq=np.vstack(ref_seq).astype(np.int8) 269 | label=np.hstack(label).astype(np.float16) 270 | ref_coordinates=np.hstack(ref_coordinates) 271 | read_name=np.hstack(read_name) 272 | ref_name=np.hstack(ref_name) 273 | 274 | idx=np.random.permutation(np.arange(len(label))) 275 | mat=mat[idx] 276 | base_qual=base_qual[idx] 277 | base_seq=base_seq[idx] 278 | ref_seq=ref_seq[idx] 279 | label=label[idx] 280 | ref_coordinates=ref_coordinates[idx] 281 | read_name=read_name[idx] 282 | ref_name=ref_name[idx] 283 | 284 | print('%s: Number of reads processed = %d.' %(str(datetime.datetime.now()), read_count), flush=True) 285 | 286 | 287 | write_to_npz(output_file_path, mat, base_qual, base_seq, ref_seq, label, ref_coordinates, read_name, ref_name, window=params['window'], norm_type=params['norm_type']) 288 | 289 | chunk+=1 290 | output_file_path=os.path.join(output,'%s.features.%d.npz' %(params['prefix'], chunk)) 291 | mat, base_qual, base_seq, ref_seq, label=[], [], [], [], [] 292 | ref_coordinates, read_name, ref_name= [], [], [] 293 | 294 | except queue.Empty: 295 | pass 296 | 297 | if read_count>0 and len(mat)>0: 298 | mat=np.vstack(mat) 299 | base_qual=np.vstack(base_qual) 300 | base_seq=np.vstack(base_seq).astype(np.int8) 301 | ref_seq=np.vstack(ref_seq).astype(np.int8) 302 | label=np.hstack(label).astype(np.float16) 303 | ref_coordinates=np.hstack(ref_coordinates) 304 | read_name=np.hstack(read_name) 305 | ref_name=np.hstack(ref_name) 306 | 307 | idx=np.random.permutation(np.arange(len(label))) 308 | mat=mat[idx] 309 | base_qual=base_qual[idx] 310 | base_seq=base_seq[idx] 311 | ref_seq=ref_seq[idx] 312 | label=label[idx] 313 | ref_coordinates=ref_coordinates[idx] 314 | read_name=read_name[idx] 315 | ref_name=ref_name[idx] 316 | 317 | print('%s: Number of reads processed = %d.' 
%(str(datetime.datetime.now()), read_count), flush=True) 318 | 319 | 320 | write_to_npz(output_file_path, mat, base_qual, base_seq, ref_seq, label, ref_coordinates, read_name, ref_name, params['window'], params['norm_type']) 321 | 322 | return 323 | 324 | def process(params, ref_pos_dict, signal_Q, output_Q, input_event, ref_seq_dict, labelled_pos_list): 325 | base_map={'A':0, 'C':1, 'G':2, 'T':3, 'U':3} 326 | 327 | window=params['window'] 328 | window_range=np.arange(-window,window+1) 329 | norm_type=params['norm_type'] 330 | 331 | div_threshold=params['div_threshold'] 332 | cigar_map={'M':0, '=':0, 'X':0, 'D':1, 'I':2, 'S':2,'H':2, 'N':1, 'P':4, 'B':4} 333 | cigar_pattern = r'\d+[A-Za-z]' 334 | 335 | seq_type=params['seq_type'] 336 | ref_available=True if params['ref'] else False 337 | 338 | while True: 339 | if (signal_Q.empty() and input_event.is_set()): 340 | break 341 | 342 | try: 343 | data=signal_Q.get(block=False) 344 | signal, move, read_dict, align_data=data 345 | 346 | is_mapped, is_forward, ref_name, reference_start,reference_end, read_length=align_data 347 | 348 | fq=read_dict['seq'] 349 | qual=read_dict['qual'] 350 | sequence_length=len(fq) 351 | reverse= not is_forward 352 | fq=revcomp(fq) if reverse else fq 353 | qual=qual[::-1] if reverse else qual 354 | 355 | if is_mapped and True: 356 | cigar_tuples = np.array([(int(x[:-1]), cigar_map[x[-1]]) for x in re.findall(cigar_pattern, read_dict['cigar'])]) 357 | ref_start=int(read_dict['ref_pos'])-1 358 | aligned_pairs=get_aligned_pairs(cigar_tuples, ref_start) 359 | else: 360 | continue 361 | 362 | init_pos_list_candidates, read_to_ref_pairs=get_candidates(fq, align_data, aligned_pairs, ref_pos_dict) 363 | init_pos_list_candidates=init_pos_list_candidates[(init_pos_list_candidates[:,0]>window)\ 364 | &(init_pos_list_candidates[:,0]<sequence_length-window)] if len(init_pos_list_candidates)>0 else init_pos_list_candidates 365 | 366 | if len(init_pos_list_candidates)==0: 367 | 368 | continue 369 | 370 | base_seq=np.array([base_map[x] for x in fq]) 371 | ref_seq=ref_seq_dict[ref_name][:,1][read_to_ref_pairs[:, 1]][::-1] if reverse else ref_seq_dict[ref_name][:,0][read_to_ref_pairs[:, 1]] 372 | 373 | 374 | label_filter_idx=np.array([np.mean(ref_seq[candidate[0]-window: candidate[0]+window+1]!=\ 375 | base_seq[candidate[0]-window: candidate[0]+window+1])<=div_threshold \ 376 | for candidate in init_pos_list_candidates]) 377 | pos_list_candidates=init_pos_list_candidates[label_filter_idx] 378 | 379 | 380 | if len(pos_list_candidates)==0: 381 | continue 382 | 383 | if not move[0]: 384 | try: 385 | tags={x.split(':')[0]:x for x in read_dict.pop('tags')} 386 | start=int(tags['ts'].split(':')[-1]) 387 | mv=tags['mv'].split(',') 388 | 389 | stride=int(mv[1]) 390 | move_table=np.fromiter(mv[2:], dtype=np.int8) 391 | move=(stride, start, move_table) 392 | read_dict['tags']=[x for x in tags.values() if x[:2] not in ['mv', 'ts', 'ML', 'MM']] 393 | except KeyError: 394 | print('Read:%s no move table or stride or signal start found' %read_dict['name']) 395 | total_unprocessed_reads.append(read_dict) 396 | continue 397 | 398 | base_qual=10**((33-np.array([ord(x) for x in qual]))/10) 399 | mean_qscore=-10*np.log10(np.mean(base_qual)) 400 | base_qual=(1-base_qual) 401 | 402 | mat=get_events(signal, move, norm_type) 403 | 404 | if seq_type=='rna': 405 | mat=np.flip(mat,axis=0) 406 | 407 | per_site_features=np.array([mat[candidate[0]-window: candidate[0]+window+1] for candidate in pos_list_candidates]) 408 | per_site_base_qual=np.array([base_qual[candidate[0]-window: candidate[0]+window+1] for candidate in
pos_list_candidates]) 409 | per_site_base_seq=np.array([base_seq[candidate[0]-window: candidate[0]+window+1] for candidate in pos_list_candidates]) 410 | per_site_ref_seq=np.array([ref_seq[candidate[0]-window: candidate[0]+window+1] for candidate in pos_list_candidates]) 411 | per_site_ref_coordinates=pos_list_candidates[:,1] 412 | per_site_label=np.array([labelled_pos_list[ref_name][1-is_forward][coord] for coord in per_site_ref_coordinates]) 413 | read_name_array=np.array([read_dict['name'] for candidate in pos_list_candidates]) 414 | ref_name_array=np.array([ref_name for candidate in pos_list_candidates]) 415 | 416 | read_chunks=[per_site_features, per_site_base_qual, per_site_base_seq, per_site_ref_seq, per_site_ref_coordinates, per_site_label, read_name_array, ref_name_array] 417 | 418 | output_Q.put(read_chunks) 419 | 420 | 421 | except queue.Empty: 422 | pass 423 | 424 | return 425 | 426 | def get_input(params, signal_Q, output_Q, input_event): 427 | chrom_list=params['chrom'] 428 | 429 | length_cutoff=params['length_cutoff'] 430 | 431 | bam=params['bam'] 432 | bam_file=pysam.AlignmentFile(bam,'rb',check_sq=False) 433 | 434 | print('%s: Building BAM index.' %str(datetime.datetime.now()), flush=True) 435 | bam_index=pysam.IndexedReads(bam_file) 436 | bam_index.build() 437 | print('%s: Finished building BAM index.' %str(datetime.datetime.now()), flush=True) 438 | 439 | input_=params['input'] 440 | signal_files= [input_] if os.path.isfile(input_) else Path(input_).rglob("*.%s" %params['file_type']) 441 | 442 | counter=0 443 | if params['file_type']=='fast5': 444 | guppy_group=params['guppy_group'] 445 | for filename in signal_files: 446 | with get_fast5_file(filename, mode="r") as f5: 447 | for read in f5.get_reads(): 448 | counter+=1 449 | if counter%10000==0: 450 | print('%s: Number of reads read = %d.' %(str(datetime.datetime.now()), counter), flush=True) 451 | if signal_Q.qsize()>10000: 452 | time.sleep(10) 453 | 454 | read_name=read.read_id 455 | try: 456 | read_iter=bam_index.find(read_name) 457 | for bam_read in read_iter: 458 | 459 | if bam_read.flag & 0x900==0 and bam_read.reference_name in chrom_list and bam_read.query_length>=length_cutoff: 460 | read_dict=bam_read.to_dict() 461 | signal=read.get_raw_data() 462 | 463 | if params['fast5_move']: 464 | segment=read.get_analysis_attributes(guppy_group)['segmentation'] 465 | start=read.get_analysis_attributes('%s/Summary/segmentation' %segment)['first_sample_template'] 466 | stride=read.get_summary_data(guppy_group)['basecall_1d_template']['block_stride'] 467 | move_table=read.get_analysis_dataset('%s/BaseCalled_template' %guppy_group, 'Move') 468 | move=(stride, start, move_table) 469 | else: 470 | move=(None,None,None) 471 | 472 | align_data=(bam_read.is_mapped if params['ref'] else False, 473 | bam_read.is_forward, bam_read.reference_name, bam_read.reference_start, bam_read.reference_end, bam_read.query_length) 474 | data=(signal, move, read_dict, align_data) 475 | signal_Q.put(data) 476 | 477 | except KeyError as error: 478 | continue 479 | 480 | else: 481 | move=(None,None,None) 482 | for filename in signal_files: 483 | with p5.Reader(filename) as reader: 484 | for read in reader.reads(): 485 | counter+=1 486 | if counter%10000==0: 487 | print('%s: Number of reads processed = %d.' 
%(str(datetime.datetime.now()), counter), flush=True) 488 | 489 | if signal_Q.qsize()>10000: 490 | time.sleep(10) 491 | 492 | read_name=str(read.read_id) 493 | try: 494 | read_iter=bam_index.find(read_name) 495 | for bam_read in read_iter: 496 | read_dict=bam_read.to_dict() 497 | if bam_read.flag & 0x900==0 and bam_read.reference_name in chrom_list and bam_read.query_length>=length_cutoff: 498 | read_dict=bam_read.to_dict() 499 | signal=read.signal 500 | align_data=(bam_read.is_mapped if params['ref'] else False, 501 | bam_read.is_forward, bam_read.reference_name, bam_read.reference_start, bam_read.reference_end, bam_read.query_length) 502 | data=(signal, move, read_dict, align_data) 503 | signal_Q.put(data) 504 | 505 | except KeyError: 506 | continue 507 | 508 | input_event.set() 509 | return 510 | 511 | def call_manager(params): 512 | bam=params['bam'] 513 | bam_file=pysam.AlignmentFile(bam,'rb',check_sq=False) 514 | header_dict=bam_file.header.to_dict() 515 | 516 | print('%s: Getting motif positions from the reference.' %str(datetime.datetime.now()), flush=True) 517 | 518 | ref_seq_dict={} 519 | ref_pos_dict={} 520 | 521 | labelled_pos_list={} 522 | 523 | if params['pos_list']: 524 | labelled_pos_list=get_pos(params['pos_list']) 525 | params['chrom']=[x for x in params['chrom'] if x in labelled_pos_list.keys()] 526 | 527 | motif_label=params['motif_label'] 528 | _=get_ref_to_num('ACGT') 529 | ref_seq_dict={} 530 | 531 | with mp.Pool(processes=params['threads']) as pool: 532 | res=pool.map(get_ref_info, zip(repeat(params), params['chrom'])) 533 | for r in res: 534 | chrom, seq_array, fwd_pos_array, rev_pos_array=r 535 | ref_seq_dict[chrom]=seq_array 536 | 537 | if params['pos_list']: 538 | ref_pos_dict[chrom]=(np.array(sorted(list(labelled_pos_list[chrom][0].keys()))).astype(int), np.array(sorted(list(labelled_pos_list[chrom][1].keys()))).astype(int)) 539 | 540 | else: 541 | ref_pos_dict[chrom]=(fwd_pos_array, rev_pos_array) 542 | labelled_pos_list[chrom]={0:{}, 1:{}} 543 | for strand in [0,1]: 544 | for pos in ref_pos_dict[chrom][strand]: 545 | labelled_pos_list[chrom][strand][pos]=float(motif_label) 546 | 547 | 548 | 549 | print('%s: Finished getting motif positions from the reference.' %str(datetime.datetime.now()), flush=True) 550 | 551 | pmanager = mp.Manager() 552 | signal_Q = pmanager.Queue() 553 | output_Q = pmanager.Queue() 554 | process_event=pmanager.Event() 555 | input_event=pmanager.Event() 556 | 557 | handlers = [] 558 | 559 | input_process = mp.Process(target=get_input, args=(params, signal_Q, output_Q, input_event)) 560 | input_process.start() 561 | handlers.append(input_process) 562 | 563 | for hid in range(max(1,params['threads']-1)): 564 | p = mp.Process(target=process, args=(params, ref_pos_dict, signal_Q, output_Q, input_event, ref_seq_dict, labelled_pos_list)); 565 | p.start(); 566 | handlers.append(p); 567 | 568 | output_process=mp.Process(target=get_output, args=(params, output_Q, process_event)); 569 | output_process.start(); 570 | 571 | for job in handlers: 572 | job.join() 573 | 574 | process_event.set() 575 | output_process.join() 576 | 577 | return 578 | 579 | if __name__ == '__main__': 580 | 581 | t=time.time() 582 | 583 | print('%s: Starting feature generation.' 
%str(datetime.datetime.now())) 584 | 585 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 586 | 587 | parser.add_argument("--bam", help='Path to bam file', type=str, required=True) 588 | parser.add_argument("--window", help='Number of bases before or after the base of interest to include in the model. Total number of bases included in the model will be 2xwindow+1.', type=int, default=10) 589 | parser.add_argument("--prefix", help='Prefix for the output files',type=str, default='output') 590 | parser.add_argument("--input", help='Path to folder containing POD5 or FAST5 files. Files will be recursively searched.', type=str, required=True) 591 | 592 | parser.add_argument("--output", help='Path to folder where features will be stored', type=str, required=True) 593 | 594 | parser.add_argument("--threads", help='Number of processors to use',type=int, default=1) 595 | 596 | parser.add_argument("--div_threshold", help='Divergence threshold: candidate sites where the fraction of read to reference mismatches within the feature window exceeds this value are excluded.',type=float, default=0.25) 597 | 598 | parser.add_argument("--reads_per_chunk", help='Number of reads to store in each output .npz feature file.',type=int, default=100000) 599 | 600 | parser.add_argument("--ref", help='Path to reference FASTA file to anchor methylation calls to reference loci. If no reference is provided, only the motif loci on reads will be used.', type=str) 601 | 602 | parser.add_argument("--pos_list", help='Text file containing a list of positions to generate features for. Use either --pos_list or --motif to specify how to choose loci for feature generation, but not both. The file should be whitespace separated with the following information on each line: chrom pos strand label. The position is a 0-based reference coordinate, strand is + for the forward strand and - for the negative strand, and label is 1 for modified and 0 for unmodified.', type=str) 603 | 604 | parser.add_argument("--file_type", help='Specify whether the signal is in FAST5 or POD5 file format. If POD5 files are used, then the move table must be in the BAM file.',choices=['fast5','pod5'], type=str,required=True) 605 | 606 | parser.add_argument("--seq_type", help='Specify DNA or direct RNA sequencing.',choices=['dna','rna'], type=str,required=True) 607 | 608 | parser.add_argument("--guppy_group", help='Name of the Guppy basecall group',type=str, default='Basecall_1D_000') 609 | parser.add_argument("--chrom", nargs='*', help='A space/whitespace separated list of contigs, e.g. chr3 chr6 chr22. If no list is provided then all chromosomes in the reference are used.') 610 | parser.add_argument("--length_cutoff", help='Minimum cutoff for read length',type=int, default=0) 611 | parser.add_argument("--fast5_move", help='Use move table from FAST5 file instead of BAM file. If this flag is set, specify a basecall group for FAST5 file using --guppy_group parameter and ensure that the FAST5 files contain a move table.', default=False, action='store_true') 612 | 613 | parser.add_argument("--motif", help='Motif for generating features followed by zero-based indices of nucleotides within the motif to generate features for. The motif and each index listed should be separated by whitespace, e.g. "--motif CGCG 0 2". Use either --pos_list or --motif to specify how to choose loci for feature generation, but not both. Features will be generated for all loci of the read that map to a reference sequence that matches the motif. Multiple indices can be specified but they should refer to the same nucleotide letter.
If you use --motif, it is assumed that all loci have the same modification label and you need to specify the label using --motif_label.', nargs='*') 614 | 615 | parser.add_argument("--motif_label", help='Modification label for the motif. 0 is for unmodified and 1 is for modified.',type=int, choices=[0,1]) 616 | 617 | parser.add_argument("--norm_type", help='How to normalize read signal.',type=str, choices=['mad', 'standard'], default='mad') 618 | 619 | args = parser.parse_args() 620 | 621 | if not args.output: 622 | args.output=os.getcwd() 623 | 624 | os.makedirs(args.output, exist_ok=True) 625 | 626 | 627 | if args.chrom: 628 | chrom_list=args.chrom 629 | else: 630 | chrom_list=pysam.Samfile(args.bam).references 631 | 632 | 633 | if args.motif and len(args.motif)>0: 634 | if args.pos_list is not None: 635 | print('Use either --motif or --pos_list but not both', flush=True) 636 | sys.exit(3) 637 | 638 | if args.motif_label is None: 639 | print('--motif_label should be specified with --motif option', flush=True) 640 | sys.exit(3) 641 | 642 | motif_seq, exp_motif_seq, motif_ind, valid_motif=motif_check(args.motif) 643 | if not valid_motif: 644 | sys.exit(3) 645 | 646 | else: 647 | motif_seq=None 648 | motif_ind=None 649 | exp_motif_seq=None 650 | 651 | if args.pos_list is None: 652 | print('Use either --motif or --pos_list', flush=True) 653 | sys.exit(3) 654 | 655 | params=dict(bam=args.bam, 656 | seq_type=args.seq_type, 657 | window=args.window, 658 | pos_list=args.pos_list, 659 | ref=args.ref, 660 | input=args.input, 661 | norm_type=args.norm_type, 662 | motif_seq=motif_seq, 663 | motif_ind=motif_ind, 664 | exp_motif_seq=exp_motif_seq, 665 | motif_label=args.motif_label, 666 | file_type=args.file_type, 667 | guppy_group=args.guppy_group, 668 | fast5_move=args.fast5_move, 669 | chrom=chrom_list, threads=args.threads, 670 | length_cutoff=args.length_cutoff, 671 | output=args.output, prefix=args.prefix, 672 | div_threshold=args.div_threshold, reads_per_chunk=args.reads_per_chunk) 673 | 674 | print(args) 675 | 676 | 677 | with open(os.path.join(args.output,'args'),'w') as file: 678 | file.write('Command: python %s\n\n\n' %(' '.join(sys.argv))) 679 | file.write('------Parameters Used For Running DeepMod2------\n') 680 | for k in vars(args): 681 | file.write('{}: {}\n'.format(k,vars(args)[k]) ) 682 | call_manager(params) 683 | print('\n%s: Time elapsed=%.4fs' %(str(datetime.datetime.now()),time.time()-t)) 684 | -------------------------------------------------------------------------------- /train/train_models.py: -------------------------------------------------------------------------------- 1 | import math, time, argparse, re, os, sys 2 | import functools, itertools, random 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import TransformerEncoder, TransformerEncoderLayer 7 | from torch import optim, Tensor 8 | import numpy as np 9 | import multiprocessing as mp 10 | import multiprocessing as mp 11 | import queue 12 | import pickle 13 | from utils import * 14 | 15 | class OneHotEncode(nn.Module): 16 | def __init__(self, num_classes: int): 17 | super().__init__() 18 | self.num_classes=num_classes 19 | def forward(self, x: Tensor) -> Tensor: 20 | return F.one_hot(x, self.num_classes) 21 | 22 | class RefandReadEmbed(nn.Module): 23 | def __init__(self, embedding_dim, embedding_type): 24 | super().__init__() 25 | 26 | self.embedding_depth=0 27 | 28 | if embedding_type=='one_hot': 29 | self.read_emb=OneHotEncode(4) 30 | self.embedding_depth+=4 31 | 32 
| self.ref_emb=OneHotEncode(5) 33 | self.embedding_depth+=5 34 | 35 | elif embedding_type=='learnable': 36 | self.read_emb=nn.Embedding(4, embedding_dim) 37 | self.embedding_depth+=embedding_dim 38 | 39 | self.ref_emb=nn.Embedding(5, embedding_dim) 40 | self.embedding_depth+=embedding_dim 41 | 42 | def forward(self, batch_base_seq, batch_ref_seq): 43 | batch_base_seq_emb=self.read_emb(batch_base_seq) 44 | batch_ref_seq_emb=self.ref_emb(batch_ref_seq) 45 | 46 | return torch.cat((batch_base_seq_emb, batch_ref_seq_emb), 2) 47 | 48 | class ReadEmbed(nn.Module): 49 | def __init__(self, embedding_dim, embedding_type): 50 | super().__init__() 51 | 52 | self.embedding_depth=0 53 | 54 | if embedding_type=='one_hot': 55 | self.read_emb=OneHotEncode(4) 56 | self.embedding_depth+=4 57 | 58 | elif embedding_type=='learnable': 59 | self.read_emb=nn.Embedding(4, embedding_dim) 60 | self.embedding_depth+=embedding_dim 61 | 62 | 63 | def forward(self, batch_base_seq, batch_ref_seq): 64 | batch_base_seq_emb=self.read_emb(batch_base_seq) 65 | 66 | return batch_base_seq_emb 67 | 68 | class SeqEmbed(nn.Module): 69 | def __init__(self, embedding_dim, embedding_type, include_ref): 70 | super().__init__() 71 | 72 | self.embedding_depth=0 73 | 74 | if include_ref: 75 | self.seq_emb=RefandReadEmbed(embedding_dim, embedding_type) 76 | 77 | else: 78 | self.seq_emb=ReadEmbed(embedding_dim, embedding_type) 79 | 80 | self.embedding_depth=self.seq_emb.embedding_depth 81 | 82 | def forward(self, batch_base_seq, batch_ref_seq): 83 | return self.seq_emb(batch_base_seq, batch_ref_seq) 84 | 85 | class PositionalEncoding(nn.Module): 86 | def __init__(self, pe_dim: int, max_len: int): 87 | super().__init__() 88 | position = torch.arange(max_len).unsqueeze(1) 89 | div_term = torch.exp(torch.arange(0, pe_dim, 2) * (-math.log(pe_dim) / (pe_dim))) 90 | pe = torch.zeros(1, max_len, pe_dim) 91 | pe[0,:, 0::2] = torch.sin(position * div_term) 92 | pe[0,:, 1::2] = torch.cos(position * div_term) 93 | self.register_buffer('pe', pe) 94 | 95 | def forward(self, x: Tensor) -> Tensor: 96 | x_pos=torch.Tensor.repeat(self.pe,(x.size(0),1,1)) 97 | x = torch.cat((x, x_pos),2) 98 | return x 99 | 100 | class PositionalEmbedding(nn.Module): 101 | def __init__(self, pe_dim: int, max_len: int): 102 | super().__init__() 103 | pos=torch.arange(max_len) 104 | self.register_buffer('pos', pos) 105 | self.pe=nn.Embedding(max_len, pe_dim) 106 | 107 | def forward(self, x: Tensor) -> Tensor: 108 | x_pos=self.pe(self.pos) 109 | x_pos=torch.Tensor.repeat(x_pos,(x.size(0),1,1)) 110 | x = torch.cat((x, x_pos),2) 111 | return x 112 | 113 | class PositionalParameter(nn.Module): 114 | def __init__(self, pe_dim: int, max_len: int): 115 | super().__init__() 116 | 117 | self.pe=torch.nn.Parameter(torch.randn(max_len, pe_dim)) 118 | 119 | def forward(self, x: Tensor) -> Tensor: 120 | x_pos=torch.Tensor.repeat(self.pe,(x.size(0),1,1)) 121 | x = torch.cat((x, x_pos),2) 122 | return x 123 | 124 | class ClassifierMiddle(nn.Module): 125 | def __init__(self, in_dim: int, num_fc: int, model_len: int): 126 | super().__init__() 127 | self.mid = model_len//2 128 | self.fc = nn.Linear(in_dim, num_fc) 129 | self.out = nn.Linear(num_fc,1) 130 | 131 | def forward(self, x): 132 | x = F.relu(self.fc(x[:,self.mid, :])) 133 | x=self.out(x) 134 | return x 135 | 136 | class ClassifierAll(nn.Module): 137 | def __init__(self, in_dim: int, num_fc: int): 138 | super().__init__() 139 | self.fc = nn.Linear(in_dim, num_fc) 140 | self.out = nn.Linear(num_fc,1) 141 | 142 | def forward(self, x): 143 
| x = torch.flatten(x, 1) 144 | x = F.relu(self.fc(x)) 145 | x=self.out(x) 146 | return x 147 | 148 | class BiLSTM(nn.Module): 149 | def __init__(self, model_dims, num_layers, dim_feedforward, num_fc, embedding_dim, embedding_type, include_ref, fc_type): 150 | super(BiLSTM, self).__init__() 151 | 152 | self.emb=SeqEmbed(embedding_dim, embedding_type, include_ref) 153 | self.model_len=model_dims[0] 154 | self.model_depth=model_dims[1]+self.emb.embedding_depth 155 | 156 | self.bilstm = nn.LSTM(input_size=self.model_depth, hidden_size=dim_feedforward, num_layers=num_layers, bidirectional=True, batch_first = True) 157 | 158 | if fc_type=='middle': 159 | self.classifier=ClassifierMiddle(in_dim=dim_feedforward*2, num_fc=num_fc, model_len=self.model_len) 160 | 161 | else: 162 | self.classifier=ClassifierAll(in_dim=self.model_len*dim_feedforward*2, num_fc=num_fc) 163 | 164 | def forward(self, batch_x, batch_base_seq, batch_ref_seq): 165 | seq_emb=self.emb(batch_base_seq, batch_ref_seq) 166 | x=torch.cat((batch_x, seq_emb), 2) 167 | x, _=self.bilstm(x) 168 | x = self.classifier(x) 169 | 170 | return x 171 | 172 | class TransformerModel(nn.Module): 173 | def __init__(self, model_dims, num_layers, dim_feedforward, num_fc, embedding_dim, embedding_type, include_ref, pe_dim, nhead, pe_type, fc_type): 174 | super(TransformerModel, self).__init__() 175 | 176 | self.emb=SeqEmbed(embedding_dim, embedding_type, include_ref) 177 | self.model_len=model_dims[0] 178 | 179 | if pe_type=='fixed': 180 | self.pe_block=PositionalEncoding(pe_dim=pe_dim, max_len=self.model_len) 181 | 182 | elif pe_type=='embedding': 183 | self.pe_block=PositionalEmbedding(pe_dim=pe_dim, max_len=self.model_len) 184 | 185 | elif pe_type=='parameter': 186 | self.pe_block=PositionalParameter(pe_dim=pe_dim, max_len=self.model_len) 187 | 188 | self.model_depth=model_dims[1]+self.emb.embedding_depth+pe_dim 189 | self.pad_length=math.ceil(self.model_depth/nhead)*nhead-self.model_depth 190 | pad=torch.zeros(1,self.model_len, self.pad_length) 191 | self.register_buffer('pad', pad) 192 | self.model_depth+=self.pad_length 193 | 194 | encoder_layer = nn.TransformerEncoderLayer(d_model=self.model_depth, nhead=nhead, dim_feedforward=dim_feedforward, dropout=0, batch_first = True) 195 | self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) 196 | 197 | if fc_type=='middle': 198 | self.classifier=ClassifierMiddle(in_dim=self.model_depth, num_fc=num_fc, model_len=self.model_len) 199 | 200 | else: 201 | self.classifier=ClassifierAll(in_dim=self.model_len*self.model_depth, num_fc=num_fc) 202 | 203 | def forward(self, batch_x, batch_base_seq, batch_ref_seq): 204 | seq_emb=self.emb(batch_base_seq, batch_ref_seq) 205 | x=torch.cat((batch_x, seq_emb), 2) 206 | x=self.pe_block(x) 207 | x_pad=torch.Tensor.repeat(self.pad,(x.size(0),1,1)) 208 | x = torch.cat((x, x_pad),2) 209 | 210 | x=self.transformer_encoder(x) 211 | x = self.classifier(x) 212 | 213 | return x 214 | 215 | 216 | 217 | def train(training_dataset, validation_dataset, validation_type, validation_fraction, model_config, epochs, prefix, retrain, batch_size, args_str, seed): 218 | print('Starting training.' 
, flush=True) 219 | torch.manual_seed(seed) 220 | model_type = model_config['model_type'] 221 | model_save_path = model_config.pop('model_save_path') 222 | 223 | if torch.cuda.is_available(): 224 | dev = "cuda:0" 225 | else: 226 | dev = "cpu" 227 | 228 | weight_counts=np.array([np.sum(np.eye(2)[np.load(f)['label'].astype(int)],axis=0) for f in itertools.chain.from_iterable(training_dataset)]) 229 | weight_counts=np.sum(weight_counts,axis=0) 230 | 231 | 232 | if model_config['weights']=='equal': 233 | pos_weight=torch.Tensor(np.array(1.0)) 234 | 235 | elif model_config['weights']=='auto': 236 | pos_weight=torch.Tensor(np.array(weight_counts[0]/weight_counts[1])) 237 | 238 | else: 239 | pos_weight=torch.Tensor(np.array(float(model_config['weights']))) 240 | 241 | print('Number of Modified Instances={}\nNumber of Un-Modified Instances={}\nPositive Label Weight={}\n'.format(weight_counts[1],weight_counts[0],pos_weight), flush=True) 242 | 243 | if model_type=='bilstm': 244 | net = BiLSTM(model_dims=model_config['model_dims'], num_layers=model_config['num_layers'], \ 245 | dim_feedforward=model_config['dim_feedforward'], \ 246 | num_fc=model_config['num_fc'], embedding_dim=model_config['embedding_dim'], \ 247 | embedding_type=model_config['embedding_type'], include_ref=model_config['include_ref'], \ 248 | fc_type=model_config['fc_type']); 249 | 250 | elif model_type=='transformer': 251 | net = TransformerModel(model_dims=model_config['model_dims'], num_layers=model_config['num_layers'], \ 252 | dim_feedforward=model_config['dim_feedforward'], \ 253 | num_fc=model_config['num_fc'], embedding_dim=model_config['embedding_dim'], \ 254 | embedding_type=model_config['embedding_type'], include_ref=model_config['include_ref'],\ 255 | pe_dim=model_config['pe_dim'], nhead=model_config['nhead'], \ 256 | pe_type=model_config['pe_type'], fc_type=model_config['fc_type']); 257 | 258 | net.to(dev); 259 | 260 | optimizer = optim.Adam(net.parameters(), lr= model_config['lr'], weight_decay=model_config['l2_coef']) 261 | 262 | if retrain: 263 | checkpoint = torch.load(retrain) 264 | net.load_state_dict(checkpoint['model_state_dict']) 265 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 266 | 267 | 268 | model_details=str(net.to(torch.device(dev))) 269 | print(model_details, flush=True) 270 | num_params=sum(p.numel() for p in net.parameters()) 271 | print('# Parameters=', num_params, flush=True) 272 | 273 | config_path=os.path.join(model_save_path, '%s.cfg' %prefix) 274 | 275 | with open(config_path, 'wb') as handle: 276 | pickle.dump(model_config, handle, protocol=pickle.HIGHEST_PROTOCOL) 277 | 278 | log_file_path=os.path.join(model_save_path, '%s.log' %prefix) 279 | 280 | train_w_wo_ref=model_config['train_w_wo_ref'] 281 | include_ref=model_config['include_ref'] 282 | 283 | if include_ref and train_w_wo_ref: 284 | list_of_evals=['Normal', 'Without_Ref'] 285 | else: 286 | list_of_evals=['Normal'] 287 | 288 | dummy_ref_seq=(4+torch.zeros(batch_size, model_config["model_dims"][0])).type(torch.LongTensor).to(dev) 289 | 290 | with open(log_file_path,'w') as log_file: 291 | log_file.write(args_str) 292 | log_file.write('\n# Parameters=%d\n' %num_params) 293 | log_file.write(model_details) 294 | 295 | for j in range(epochs): 296 | net.train() 297 | 298 | metrics_train={g:{'TP':0,'FP':0,'FN':0,'loss':0,'len':0,'true':0} for g in list_of_evals} 299 | metrics_test={g:{'TP':0,'FP':0,'FN':0,'loss':0,'len':0,'true':0} for g in list_of_evals} 300 | 301 | t=time.time() 302 | 303 | 
train_gen=generate_batches_mixed_can_mod(training_dataset, validation_type, validation_fraction, data_type="train", batch_size=batch_size) 304 | 305 | for batch in train_gen: 306 | batch_x, batch_base_seq, batch_ref_seq, batch_y =batch 307 | batch_x, batch_base_seq, batch_ref_seq, batch_y=batch_x.to(dev), batch_base_seq.to(dev), batch_ref_seq.to(dev), batch_y.to(dev) 308 | 309 | optimizer.zero_grad() 310 | score= net(batch_x, batch_base_seq, batch_ref_seq) 311 | loss = torch.nn.functional.binary_cross_entropy_with_logits(score, batch_y,pos_weight=pos_weight) 312 | 313 | loss.backward() 314 | optimizer.step() 315 | 316 | get_metrics(metrics_train,'Normal', batch_y, score, loss) 317 | 318 | if include_ref and train_w_wo_ref: 319 | dummy_batch_ref_seq=dummy_ref_seq[:batch_ref_seq.size(0)] 320 | optimizer.zero_grad() 321 | score= net(batch_x, batch_base_seq, dummy_batch_ref_seq) 322 | loss = torch.nn.functional.binary_cross_entropy_with_logits(score, batch_y,pos_weight=pos_weight) 323 | 324 | loss.backward() 325 | optimizer.step() 326 | get_metrics(metrics_train,'Without_Ref', batch_y, score, loss) 327 | 328 | with torch.no_grad(): 329 | net.eval() 330 | 331 | if validation_type=='split': 332 | test_gen=generate_batches(list(itertools.chain.from_iterable(training_dataset)), validation_type, validation_fraction, data_type="test", batch_size=batch_size) 333 | 334 | else: 335 | test_gen=generate_batches(validation_dataset, validation_type, validation_fraction, data_type="test", batch_size=batch_size) 336 | 337 | for batch in test_gen: 338 | batch_x, batch_base_seq, batch_ref_seq, batch_y = batch 339 | batch_x, batch_base_seq, batch_ref_seq, batch_y=batch_x.to(dev), batch_base_seq.to(dev), batch_ref_seq.to(dev), batch_y.to(dev) 340 | 341 | score= net(batch_x, batch_base_seq, batch_ref_seq) 342 | loss = torch.nn.functional.binary_cross_entropy_with_logits(score, batch_y,pos_weight=pos_weight) 343 | 344 | get_metrics(metrics_test,'Normal', batch_y, score, loss) 345 | 346 | if include_ref and train_w_wo_ref: 347 | dummy_batch_ref_seq=dummy_ref_seq[:batch_ref_seq.size(0)] 348 | score= net(batch_x, batch_base_seq, dummy_batch_ref_seq) 349 | loss = torch.nn.functional.binary_cross_entropy_with_logits(score, batch_y,pos_weight=pos_weight) 350 | 351 | get_metrics(metrics_test,'Without_Ref', batch_y, score, loss) 352 | 353 | train_str, _ = get_stats(metrics_train, 'Training') 354 | test_str, total_test_acc = get_stats(metrics_test, 'Testing') 355 | 356 | epoch_log='\n\nEpoch %d: #Train=%d #Test=%d Time=%.4f\n%s\n\n%s'\ 357 | %(j+1, sum(x['len'] for x in metrics_train.values()), sum(x['len'] for x in metrics_test.values()), time.time()-t, 358 | train_str, test_str) 359 | print(epoch_log, flush=True) 360 | log_file.write(epoch_log) 361 | log_file.flush() 362 | os.fsync(log_file.fileno()) 363 | 364 | model_path=os.path.join(model_save_path, 'model.epoch%d.%.4f' %(j+1, total_test_acc)) 365 | torch.save({ 366 | 'epoch': j+1, 367 | 'model_state_dict': net.state_dict(), 368 | 'optimizer_state_dict': optimizer.state_dict()}, model_path) 369 | 370 | return net 371 | 372 | if __name__=='__main__': 373 | start_time=time.time() 374 | 375 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 376 | 377 | parser.add_argument("--mixed_training_dataset", nargs='*', help='Training dataset with mixed labels. 
A whitespace separated list of folders containing .npz files or paths to individual .npz files.') 378 | 379 | parser.add_argument("--can_training_dataset", nargs='*', help='Training dataset with unmodified or canonical base labels. A whitespace separated list of folders containing .npz files or paths to individual .npz files.') 380 | parser.add_argument("--mod_training_dataset", nargs='*', help='Training dataset with modified labels. A whitespace separated list of folders containing .npz files or paths to individual .npz files.') 381 | 382 | parser.add_argument("--validation_type",choices=['split','dataset'], help='How the validation is performed. "split" means that a fraction of the training dataset, specified by --validation_fraction, will be used for validation. "dataset" means that an additional validation dataset is provided via the --validation_dataset parameter.',default="split") 383 | parser.add_argument("--validation_fraction", help='Fraction of the training dataset to use for validation if --validation_type is set to "split", otherwise ignored.', type=float, default=0.2) 384 | parser.add_argument("--validation_dataset",nargs='*', help='Validation dataset if --validation_type is set to "dataset". A whitespace separated list of folders containing .npz files or paths to individual .npz files.') 385 | 386 | parser.add_argument("--prefix", help='Prefix name for the model checkpoints', default='model') 387 | 388 | parser.add_argument("--weights", help='Weight of the positive (modified) label used in the binary cross entropy loss; the negative (unmodified) label always has a fixed weight of 1. A higher weight for modified labels favors recall, while a lower weight favors precision. Choices are "equal", "auto", or you can specify the weight of positive (modified) labels directly. "equal" assigns a weight of 1 to modified labels, "auto" assigns a weight=num_negative_samples/num_positive_samples to modified labels.', default='equal') 389 | 390 | 391 | parser.add_argument("--model_save_path", help='Folder path for saving model checkpoints') 392 | parser.add_argument("--epochs", help='Number of total epochs', default=100, type=int) 393 | parser.add_argument("--batch_size", help='Batch size', default=256, type=int) 394 | parser.add_argument("--retrain", help='Path to a model for retraining', default=None) 395 | 396 | parser.add_argument("--fc_type", help='Type of fully connected layer to use in the classifier.', type=str, default='all', choices=['middle', 'all']) 397 | parser.add_argument("--model_type", help='Type of model to use', type=str, choices=['bilstm', 'transformer']) 398 | 399 | parser.add_argument("--num_layers", help='Number of transformer encoder or BiLSTM layers', type=int, default=3) 400 | parser.add_argument("--dim_feedforward", help='Dimension of feedforward layers in transformer encoder or size of hidden units in BiLSTM layers', type=int, default=100) 401 | parser.add_argument("--num_fc", help='Size of fully connected layer between encoder/BiLSTM and classifier', type=int, default=16) 402 | parser.add_argument("--embedding_dim", help='Size of embedding dimension for read and reference bases', type=int, default=4) 403 | parser.add_argument("--embedding_type", help='Type of embedding for bases', type=str, choices=['learnable', 'one_hot'], default='one_hot') 404 | parser.add_argument("--pe_dim", help='Dimension for positional encoding/embedding for transformer model.', type=int, default=16) 405 | parser.add_argument("--pe_type", help='Type of positional encoding/embedding for transformer model. 
"fixed" is a sinusoidal encoding, "embedding" is a dictionary lookup, and "parameter" is a learnable weight matrix.', type=str, choices=['fixed', 'embedding', 'parameter'], default='fixed') 406 | parser.add_argument("--nhead", help='Number of self-attention heads in the transformer encoder.', type=int, default=4) 407 | parser.add_argument("--include_ref", help='Whether to include the reference sequence as features. Recommended.', default=False, action='store_true') 408 | parser.add_argument("--train_w_wo_ref", help='Train on examples both with and without the reference sequence. Recommended if you will be using reference-free modification detection.', default=False, action='store_true') 409 | 410 | parser.add_argument("--lr", help='Learning rate', type=float, default=1e-4) 411 | parser.add_argument("--l2_coef", help='L2 regularization coefficient', type=float, default=1e-5) 412 | parser.add_argument("--seed", help='Random seed to use in PyTorch for reproducibility or reinitialization of weights', default=None) 413 | 414 | args = parser.parse_args() 415 | 416 | os.makedirs(args.model_save_path, exist_ok=True) 417 | 418 | mixed_training_dataset=get_files(args.mixed_training_dataset) 419 | can_training_dataset=get_files(args.can_training_dataset) 420 | mod_training_dataset=get_files(args.mod_training_dataset) 421 | 422 | validation_dataset=get_files(args.validation_dataset) 423 | validation_type=args.validation_type 424 | validation_fraction=args.validation_fraction 425 | 426 | valid_data, window, norm_type, strides_per_base, model_depth, full_signal = check_training_files(mixed_training_dataset, can_training_dataset,\ 427 | mod_training_dataset, validation_dataset) 428 | 429 | if not valid_data: 430 | sys.exit(3) 431 | 432 | model_len=strides_per_base*(2*window+1) # total input length: strides_per_base signal strides for each of the 2*window+1 bases 433 | 434 | model_config = dict(model_dims=(model_len,model_depth+1),window=window, model_type=args.model_type, 435 | num_layers=args.num_layers, dim_feedforward=args.dim_feedforward, 436 | num_fc=args.num_fc, embedding_dim=args.embedding_dim, 437 | embedding_type=args.embedding_type, include_ref=args.include_ref, 438 | pe_dim=args.pe_dim, nhead=args.nhead, pe_type=args.pe_type, 439 | l2_coef=args.l2_coef, lr=args.lr, model_save_path=args.model_save_path, fc_type=args.fc_type, 440 | train_w_wo_ref=args.train_w_wo_ref, weights=args.weights, norm_type=norm_type, full_signal=full_signal, strides_per_base=strides_per_base) 441 | 442 | args_dict=vars(args) 443 | args_str=''.join('%s: %s\n' %(k,str(v)) for k,v in args_dict.items()) 444 | print(args_str, flush=True) 445 | 446 | seed =random.randint(0, 0xffff_ffff_ffff_ffff) if args.seed is None else int(args.seed) 447 | 448 | training_dataset = [mixed_training_dataset, can_training_dataset, mod_training_dataset] 449 | 450 | with open(os.path.join(args.model_save_path,'args'),'w') as file: 451 | file.write('Command: python %s\n\n\n' %(' '.join(sys.argv))) 452 | file.write('------Parameters Used For Running DeepMod2------\n') 453 | for k in vars(args): 454 | file.write('{}: {}\n'.format(k,vars(args)[k]) ) 455 | 456 | res=train(training_dataset, validation_dataset, validation_type, validation_fraction, model_config, epochs=args.epochs,prefix=args.prefix, retrain=args.retrain, batch_size=args.batch_size, args_str=args_str, seed=seed) 457 | 458 | print('Time taken=%.4f' %(time.time()-start_time), flush=True) 459 | -------------------------------------------------------------------------------- /train/utils.py: -------------------------------------------------------------------------------- 1 | import 
numpy as np 2 | import torch, os, random 3 | from pathlib import Path 4 | from sklearn.model_selection import train_test_split 5 | import sklearn 6 | import itertools 7 | 8 | def get_files(input_list): 9 | out_list=[] 10 | if type(input_list)==type(None): 11 | return out_list 12 | for item in input_list: 13 | 14 | if os.path.isdir(item): 15 | out_list.extend(list(Path(item).rglob("*.npz"))) 16 | 17 | elif item[-4:]=='.npz': 18 | out_list.append(item) 19 | 20 | random.seed(0) 21 | random.shuffle(out_list) 22 | return out_list 23 | 24 | def read_from_file(file_name, validation_type, validation_fraction, data_type, get_pos=False): 25 | data=np.load(file_name) 26 | mat=data['mat'] 27 | base_qual=data['base_qual'] 28 | 29 | features=np.dstack((mat, base_qual[:,:,np.newaxis])) 30 | features=torch.Tensor(features) 31 | 32 | base_seq=torch.Tensor(data['base_seq']).type(torch.LongTensor) 33 | ref_seq=torch.Tensor(data['ref_seq']).type(torch.LongTensor) 34 | 35 | labels=torch.Tensor(data['label'][:,np.newaxis]) 36 | 37 | if get_pos: 38 | pos=data['ref_coordinates'] 39 | chrom=data['ref_name'] 40 | pos_data=np.vstack([chrom, pos]).T 41 | 42 | return features, base_seq, ref_seq, labels, pos_data 43 | 44 | if validation_type=='split': 45 | features_train, features_test, base_seq_train, base_seq_test, \ 46 | ref_seq_train, ref_seq_test, labels_train, labels_test=\ 47 | train_test_split(features, base_seq, ref_seq, labels, test_size=validation_fraction, random_state=42) 48 | if data_type=='train': 49 | return features_train, base_seq_train, ref_seq_train, labels_train 50 | else: 51 | return features_test, base_seq_test, ref_seq_test, labels_test 52 | 53 | else: 54 | return features, base_seq, ref_seq, labels 55 | 56 | def check_training_files(mixed_training_dataset, can_training_dataset,\ 57 | mod_training_dataset, validation_dataset): 58 | norm_type=[str(np.load(file)['norm_type']) for file in itertools.chain.from_iterable([mixed_training_dataset, can_training_dataset,\ 59 | mod_training_dataset, validation_dataset])] 60 | window=[int(np.load(file)['window']) for file in itertools.chain.from_iterable([mixed_training_dataset, can_training_dataset,\ 61 | mod_training_dataset, validation_dataset])] 62 | strides_per_base=[int(np.load(file)['strides_per_base']) for file in itertools.chain.from_iterable([mixed_training_dataset, can_training_dataset,\ 63 | mod_training_dataset, validation_dataset])] 64 | 65 | model_depth=[int(np.load(file)['model_depth']) for file in itertools.chain.from_iterable([mixed_training_dataset, can_training_dataset,\ 66 | mod_training_dataset, validation_dataset])] 67 | full_signal=np.unique([np.load(file)['full_signal'] for file in itertools.chain.from_iterable([mixed_training_dataset, can_training_dataset, mod_training_dataset, validation_dataset])]) 68 | 69 | if len(set(window))==1 and len(set(norm_type))==1 and len(set(strides_per_base))==1 and len(full_signal)==1: 70 | return True, window[0], norm_type[0], strides_per_base[0], model_depth[0], full_signal[0] 71 | 72 | elif len(set(window))>1: 73 | print('Inconsistent dataset with multiple window sizes') 74 | 75 | elif len(set(norm_type))>1: 76 | print('Inconsistent dataset with multiple normalization types') 77 | 78 | elif len(set(strides_per_base))>1: 79 | print('Inconsistent dataset with multiple strides_per_base') 80 | 81 | return False, window, norm_type, strides_per_base, model_depth, full_signal 82 | 83 | def generate_batches(files, validation_type, validation_fraction, data_type, batch_size=512): 84 | counter = 0 85 | 86 | 
print_freq=max(1, len(files)//10) 87 | 88 | while counter<len(files): 121 | batch_x=torch.vstack([d[0] for d in multi_batch_data if len(d[0])>0]) 122 | batch_base_seq=torch.vstack([d[1] for d in multi_batch_data if len(d[1])>0]) 123 | batch_ref_seq=torch.vstack([d[2] for d in multi_batch_data if len(d[2])>0]) 124 | batch_y=torch.vstack([d[3] for d in multi_batch_data if len(d[3])>0]) 125 | 126 | yield batch_x, batch_base_seq, batch_ref_seq, batch_y 127 | 128 | def get_stats(metrics_dict, dtype): 129 | 130 | loss_str='{} Loss:'.format(dtype) 131 | acc_str='{} Accuracy:'.format(dtype) 132 | prec_str='{} Precision:'.format(dtype) 133 | rec_str='{} Recall:'.format(dtype) 134 | f1_str='{} F1:'.format(dtype) 135 | 136 | for g in sorted(metrics_dict.keys()): 137 | 138 | acc=metrics_dict[g]['true']/max(1,metrics_dict[g]['len']) 139 | loss=metrics_dict[g]['loss']/max(1,metrics_dict[g]['len']) 140 | precision=metrics_dict[g]['TP']/max(1,metrics_dict[g]['TP']+metrics_dict[g]['FP']) 141 | recall=metrics_dict[g]['TP']/max(1,metrics_dict[g]['TP']+metrics_dict[g]['FN']) 142 | f1=2*precision*recall/(precision+recall) if precision*recall!=0 else 0 143 | 144 | if len(metrics_dict.keys())==1: 145 | x='Total' 146 | total_acc=acc 147 | else: 148 | x=g 149 | 150 | loss_str+=' %s: %.4f' %(x, loss) 151 | acc_str+=' %s: %.4f' %(x, acc) 152 | prec_str+=' %s: %.4f' %(x, precision) 153 | rec_str+=' %s: %.4f' %(x, recall) 154 | f1_str+=' %s: %.4f' %(x, f1) 155 | 156 | if len(metrics_dict.keys())>1: 157 | x='Total' 158 | 159 | acc=sum(f['true'] for f in metrics_dict.values())/max(1,sum(f['len'] for f in metrics_dict.values())) 160 | loss=sum(f['loss'] for f in metrics_dict.values())/max(1,sum(f['len'] for f in metrics_dict.values())) 161 | precision=sum(f['TP'] for f in metrics_dict.values())/max(1,sum(f['TP'] for f in metrics_dict.values()) +sum(f['FP'] for f in metrics_dict.values())) 162 | recall=sum(f['TP'] for f in metrics_dict.values())/max(1,sum(f['TP'] for f in metrics_dict.values())+sum(f['FN'] for f in metrics_dict.values())) 163 | f1=2*precision*recall/(precision+recall) if precision*recall!=0 else 0 164 | 165 | total_acc=acc 166 | 167 | loss_str+=' %s: %.4f' %(x, loss) 168 | acc_str+=' %s: %.4f' %(x, acc) 169 | prec_str+=' %s: %.4f' %(x, precision) 170 | rec_str+=' %s: %.4f' %(x, recall) 171 | f1_str+=' %s: %.4f' %(x, f1) 172 | 173 | return '\n'.join([loss_str, acc_str,prec_str, rec_str, f1_str]), total_acc 174 | 175 | def get_metrics(metrics_dict,name, batch_y, score, loss): 176 | eval_counts=sklearn.metrics.confusion_matrix(batch_y.cpu(),(score>0).cpu(),labels=[0,1]) 177 | metrics_dict[name]['len']+=len(batch_y) 178 | metrics_dict[name]['TP']+=eval_counts[1,1] 179 | metrics_dict[name]['FP']+=eval_counts[0,1] 180 | metrics_dict[name]['FN']+=eval_counts[1,0] 181 | metrics_dict[name]['true']=metrics_dict[name]['true']+eval_counts[0,0]+eval_counts[1,1] 182 | metrics_dict[name]['loss']+=loss.item()*len(batch_y) 183 | --------------------------------------------------------------------------------
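A minimal, hypothetical invocation of `train/train_models.py`, assuming feature `.npz` files have already been generated (for example with `train/generate_features.py`) into a `features/` folder; the paths, epoch count, and batch size below are placeholders, while the flags are taken from the argument parser shown above:
```
python train/train_models.py --mixed_training_dataset features/ --model_type bilstm --include_ref --train_w_wo_ref --validation_type split --validation_fraction 0.2 --weights auto --model_save_path trained_model/ --epochs 50 --batch_size 256
```
Per-epoch checkpoints are written to `--model_save_path` as `model.epoch<N>.<test_accuracy>`, together with a pickled model configuration (`<prefix>.cfg`) and a training log (`<prefix>.log`).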