├── LICENSE
├── README.md
├── deepmod2
├── docs
│   ├── Example.md
│   ├── Install.md
│   ├── Training.md
│   └── Usage.md
├── environment.yml
├── plot_utils
│   ├── README.md
│   ├── Signal_Plot_Examples.ipynb
│   └── plot.py
├── src
│   ├── detect.py
│   ├── models.py
│   ├── models
│   │   ├── bilstm.cfg
│   │   ├── bilstm
│   │   │   ├── R10.4.1_4kHz_v3.5
│   │   │   ├── R10.4.1_4kHz_v4.1
│   │   │   ├── R10.4.1_5kHz_v4.3
│   │   │   ├── R10.4.1_5kHz_v5.0
│   │   │   └── R9.4.1
│   │   ├── transformer.cfg
│   │   └── transformer
│   │       ├── R10.4.1_4kHz_v3.5
│   │       ├── R10.4.1_4kHz_v4.1
│   │       ├── R10.4.1_5kHz_v4.3
│   │       ├── R10.4.1_5kHz_v5.0
│   │       └── R9.4.1
│   └── utils.py
└── train
    ├── generate_features.py
    ├── train_models.py
    └── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Wang Genomics Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeepMod2
2 | DeepMod2 is a computational tool for detecting DNA 5mC methylation from Oxford Nanopore reads. It uses BiLSTM and Transformer models to predict per-read and per-site 5mC methylation at CpG sites and produces a methylation-tagged BAM file. DeepMod2 can call methylation from POD5 and FAST5 files basecalled with Guppy or Dorado and provides models for R10.4.1 and R9.4.1 flowcells.
3 |
4 |
5 |
6 | DeepMod2 is distributed under the [MIT License by Wang Genomics Lab](https://wglab.mit-license.org/).
7 |
8 | ### Citing DeepMod2
9 | Ahsan, M.U., Gouru, A., Chan, J. et al. A signal processing and deep learning framework for methylation detection using Oxford Nanopore sequencing. Nat Commun 15, 1448 (2024). https://doi.org/10.1038/s41467-024-45778-y
10 | ## Installation
11 | Please refer to [Installation](https://github.com/WGLab/DeepMod2/blob/main/docs/Install.md) for how to install DeepMod2.
12 |
13 | ## Inference
14 | Quick usage guide for model inference:
15 | 1. Basecall your FAST5/POD5 files with Dorado (using `--emit-moves`) or Guppy (using `--bam_out --moves_out`) to get a BAM file with move tables:
16 | ```
17 | dorado basecaller MODEL INPUT_DIR --emit-moves > basecall.bam
18 | ```
19 | Make sure to use the appropriate Guppy/Dorado model for your sequencing kit. You can supply a reference genome to Guppy/Dorado to get aligned BAM files, or use minimap2 to align these BAM files later.
20 | 2. (Optional but recommended) Align basecalled reads to a reference genome while retaining the move tables:
21 | ```
22 | samtools fastq basecall.bam -T mv,ts | minimap2 -ax map-ont ref.fa - -y -t NUM_THREADS | samtools view -o aligned.bam
23 | ```
24 | 3. Run DeepMod2 by providing the BAM file and the folder containing FAST5 or POD5 signal files as inputs. If the BAM file is aligned, you can provide a reference FASTA file to get reference-anchored methylation calls and per-site frequencies. Specify the model you want to use and the file type of the raw signal files. Use multiple cores and/or GPUs for speedup.
25 |
26 | a) If using an aligned BAM file input:
27 | ```
28 | python PATH_TO_DEEPMOD2_REPOSITORY/deepmod2 detect --bam reads.bam --input INPUT_DIR --model MODEL --file_type FILE_TYPE --threads NUM_THREADS --ref ref.fa --output MOD_CALLS
29 | ```
30 |
31 | b) If using an unaligned BAM file input:
32 | ```
33 | python PATH_TO_DEEPMOD2_REPOSITORY/deepmod2 detect --bam reads.bam --input INPUT_DIR --model MODEL --file_type FILE_TYPE --threads NUM_THREADS --output MOD_CALLS
34 | ```
35 | This will give you a per-read prediction text file `MOD_CALLS/output.per_read`, a per-site prediction file `MOD_CALLS/output.per_site`, a per-site prediction file with both strands aggregated `MOD_CALLS/output.per_site.aggregated`, and a methylation-annotated BAM file `MOD_CALLS/output.bam`. A programmatic sanity check of these outputs is sketched after this list.
36 | 4. Visualize the annotated BAM file produced by DeepMod2 in IGV. In IGV, select 'Color alignments by' and 'base modifications (5mC)'. The following steps will allow you to open the tagged BAM file in IGV:
37 |
38 | a) If an aligned BAM is given to DeepMod2, you only need to sort and index the DeepMod2 methylation tagged BAM file:
39 | ```
40 | samtools sort MOD_CALLS/output.bam -o MOD_CALLS/final.bam --write-index
41 | ```
42 |
43 | b) If an unaligned BAM is given to DeepMod2, first align the DeepMod2 methylation tagged BAM file (while preserving methylation tags MM and ML), then sort and index it:
44 | ```
45 | samtools fastq MOD_CALLS/output.bam -T MM,ML,mv,ts | minimap2 -ax map-ont ref.fa - -y -t NUM_THREADS | samtools sort -o MOD_CALLS/final.bam --write-index
46 | ```
47 |
48 |
49 |
50 |
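As a quick sanity check before visualization, you can confirm that the 5mC calls were written to the MM/ML tags of the tagged BAM file. Below is a minimal, illustrative pysam sketch (pysam is already a DeepMod2 dependency); it assumes the `MOD_CALLS/output.bam` file from step 3:

```
# Illustrative sketch: inspect 5mC calls stored in the MM/ML tags of the tagged BAM
import pysam

with pysam.AlignmentFile("MOD_CALLS/output.bam", "rb", check_sq=False) as bam:
    read = next(bam.fetch(until_eof=True))
    # modified_bases maps (canonical base, strand, mod code) -> [(read position, 0-255 score), ...]
    mods = read.modified_bases or {}
    calls = mods.get(("C", 0, "m"), [])
    print(read.query_name, len(calls), "5mC calls on the first read")
```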
51 | Please refer to [Usage.md](docs/Usage.md) for details on how to use DeepMod2.
52 |
53 | ## Training
54 | For a detailed usage guide for model training, refer to [Training.md](docs/Training.md) and the code under [train](train/).
55 |
56 | ## Signal Plotting
57 | For a detailed usage guide for signal plotting and comparison, refer to the Jupyter notebook [Signal_Plot_Examples.ipynb](plot_utils/Signal_Plot_Examples.ipynb) and the code under [plot_utils](plot_utils/). The code works with signal files in POD5 format and aligned BAM files with move tables. The Jupyter notebook uses the following data: [plot_files.tar.gz](https://github.com/WGLab/DeepMod2/files/14985308/plot_files.tar.gz).
58 |
59 | ## Models
60 | The following models for 5mC detection in CpG motifs are provided in the repository. Use `--model MODEL_NAME` to specify the model to use; you only need to provide the name of the model, not the path to it. Each model is compatible with a different Dorado or Guppy basecalling model version.
61 | |Model Architecture|DeepMod2 Model Name|Flowcell<br>(Sampling Rate)|Compatible Dorado/Guppy Basecalling Model|
62 | |-|-|-|-|
63 | |BiLSTM|**bilstm_r10.4.1_5khz_v5.0**|R10.4.1 (5kHz)|**dna_r10.4.1_e8.2_400bps_(fast\|hac\|sup)@v5.0.0**|
64 | |Transformer|**transformer_r10.4.1_5khz_v5.0**|R10.4.1 (5kHz)|**dna_r10.4.1_e8.2_400bps_(fast\|hac\|sup)@v5.0.0**|
65 | |BiLSTM|**bilstm_r10.4.1_5khz_v4.3**|R10.4.1 (5kHz)|**dna_r10.4.1_e8.2_400bps_(fast\|hac\|sup)@v4.3.0**|
66 | |Transformer|**transformer_r10.4.1_5khz_v4.3**|R10.4.1 (5kHz)|**dna_r10.4.1_e8.2_400bps_(fast\|hac\|sup)@v4.3.0**|
67 | |BiLSTM|**bilstm_r10.4.1_4khz_v4.1**|R10.4.1 (4kHz)|**dna_r10.4.1_e8.2_400bps_(fast\|hac\|sup)@v4.1.0** in Dorado<br>**dna_r10.4.1_e8.2_400bps_(fast\|hac\|sup).cfg** in Guppy 6.5.7|
68 | |Transformer|**transformer_r10.4.1_4khz_v4.1**|R10.4.1 (4kHz)|**dna_r10.4.1_e8.2_400bps_(fast\|hac\|sup)@v4.1.0** in Dorado<br>**dna_r10.4.1_e8.2_400bps_(fast\|hac\|sup).cfg** in Guppy 6.5.7|
69 | |BiLSTM|**bilstm_r10.4.1_4khz_v3.5**<br>(Published in DeepMod2 paper)|R10.4.1 (4kHz)|**dna_r10.4.1_e8.2_400bps_(fast\|hac\|sup)@v3.5.2** in Dorado<br>**dna_r10.4.1_e8.2_400bps_(fast\|hac\|sup).cfg** in Guppy 6.3.8|
70 | |Transformer|**transformer_r10.4.1_4khz_v3.5**<br>(Published in DeepMod2 paper)|R10.4.1 (4kHz)|**dna_r10.4.1_e8.2_400bps_(fast\|hac\|sup)@v3.5.2** in Dorado<br>**dna_r10.4.1_e8.2_400bps_(fast\|hac\|sup).cfg** in Guppy 6.3.8|
71 | |BiLSTM|**bilstm_r9.4.1**<br>(Published in DeepMod2 paper)|R9.4.1 (4kHz)|**dna_r9.4.1_e8_(hac\|sup)@v3.3** in Dorado<br>**dna_r9.4.1_450bps_(hac\|sup).cfg** in Guppy 6.3.8 and 6.5.7|
72 | |Transformer|**transformer_r9.4.1**<br>(Published in DeepMod2 paper)|R9.4.1 (4kHz)|**dna_r9.4.1_e8_(hac\|sup)@v3.3** in Dorado<br>**dna_r9.4.1_450bps_(hac\|sup).cfg** in Guppy 6.3.8 and 6.5.7|
73 |
74 | ## Examples
75 |
76 | Please refer to [Example](https://github.com/WGLab/DeepMod2/blob/main/docs/Example.md) for a complete tutorial on how to run DeepMod2 under various scenarios. A test dataset consisting of a small example of ~60 reads in both POD5 and FAST5 format, as well as the expected results, is available in this release: https://github.com/WGLab/DeepMod2/releases/tag/v0.3.0.
77 |
--------------------------------------------------------------------------------
/deepmod2:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import time, itertools, torch
4 |
5 | import datetime, os, shutil, argparse, sys, pysam
6 |
7 | from src import utils
8 |
9 | if __name__ == '__main__':
10 |
11 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
12 | parser.add_argument("--print_models", help='Print details of models available', default=False, action='store_true')
13 | main_subparsers = parser.add_subparsers(title="Options", dest="option")
14 |
15 | parent_parser = argparse.ArgumentParser(add_help=False,)
16 | parent_parser.add_argument("--prefix", help='Prefix for the output files',type=str, default='output')
17 | parent_parser.add_argument("--output", help= 'Path to folder where intermediate and final files will be stored, default is current working directory', type=str)
18 |
19 |
20 |
21 | parent_parser.add_argument("--qscore_cutoff", help='Minimum cutoff for mean quality score of a read',type=float, default=0)
22 | parent_parser.add_argument("--length_cutoff", help='Minimum cutoff for read length',type=int, default=0)
23 |
24 | parent_parser.add_argument("--mod_t", help= 'Probability threshold for a per-read prediction to be considered modified. Only predictiond with probability >= mod_t will be considered as modified for calculation of per-site modification levels.', default=0.5, type=float)
25 |
26 | parent_parser.add_argument("--unmod_t", help= 'Probability threshold for a per-read prediction to be considered unmodified. Only predictiond with probability < unmod_t will be considered as unmodified for calculation of per-site modification levels.', default=0.5, type=float)
27 |
28 | parent_parser.add_argument("--include_non_cpg_ref", help='Include non-CpG reference loci in per-site output where reads have CpG motif.',default=False, action='store_true')
29 |
30 | detect_parser = main_subparsers.add_parser("detect", parents=[parent_parser],
31 | add_help=True,
32 | help="Call methylation from Guppy or Dorado basecalled POD5/FAST5 files using move tables for signal alignment.", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
33 |
34 | detect_required=detect_parser.add_argument_group("Required Arguments")
35 |
36 | detect_parser.add_argument("--motif", help='Motif for detecting modifications followed by zero-based indices of nucleotides within the motif to call modification on. Default is CpG methylation "CG 0". Modification will be called for motif matches on the read and reference unless --reference_motif_only is used to restrict to reference motif matches only. Multiple indices can be specified but they should refer to the same nucleotide letter. The motif and each index listed should be separated by whitespace, e.g. "--motif CGCG 0 2"', nargs='*', default=['CG','0'])
37 |
38 | detect_parser.add_argument("--mod_symbol", help='Symbol to use for modified base in BAM tag MM. Default is to use "m" for 5mC in CG motif, and for other motifs the default is the canonical nucleotide symbol.' , type=str)
39 |
40 | detect_parser.add_argument("--mod_positions", help='A tab separated list of reference coordinates to call modification on. Modifications will only be called on reference positions specified that overlap with the motif, and no modification detection will be performed on other loci. The file shoule have the following format: "contig position strand" on each line. Position should be zero-based and strand should be "+" or "-".', default=None)
41 |
42 | detect_parser.add_argument("--reference_motif_only", help='Restrict modification calling to reference motif matches only', default=False, action='store_true')
43 |
44 | detect_parser.add_argument("--seq_type", help='Specify DNA or direct RNA sequencing.',choices=['dna','rna'], type=str,required=True)
45 |
46 | detect_parser.add_argument("--threads", help='Number of threads to use for processing signal and running model inference. If a GPU is used for inference, then --threads number of threads will be running on GPU concurrently. The total number of threads used by DeepMod2 is equal to --threads plus --bam_threads. It is recommended to run DeepMod2 with mutliple cores, and use at least 4 bam_threads for compressing BAM file.',type=int, default=4)
47 |
48 | detect_parser.add_argument("--ref", help='Path to reference FASTA file to anchor methylation calls to reference loci. If no reference is provided, only the motif loci on reads will be used.', type=str)
49 |
50 |
51 | detect_required.add_argument("--model", help='Name of the model. Recommended model for R9.4.1 flowcells is "bilstm_r9.4.1", for R10.4.1 flowcell (5kHz sampling) it is "bilstm_r10.4.1_5khz_v4.3", and for R10.4.1 flowcell (4kHz sampling) it is "bilstm_r10.4.1_4khz_v4.1". Use --print_models to display all models available along with compatible basecaller models. For custom models, provide the model config file and model checkpoint path separated by comma, e.g. "model.cfg,modelfile"',type=str, required=True)
52 |
53 | detect_required.add_argument("--bam", help='Path to aligned or unaligned BAM file. It is ideal to have move table in BAM file but move table from FAST5 fies can also be used. Aligned BAM file is required for reference anchored methylation calls, otherwise only the motif loci on reads will be called.', type=str, required=True)
54 |
55 | detect_required.add_argument("--file_type", help='Specify whether the signal is in FAST5 or POD5 file format. If POD5 file is used, then move table must be in BAM file.',choices=['fast5','pod5'], type=str,required=True)
56 |
57 | detect_required.add_argument("--input", help='Path to POD5/FAST5 file or folder containing POD5/FAST5 files. If folder provided, then POD5/FAST5 files will be recusrviely searched', type=str, required=True)
58 |
59 | detect_parser.add_argument("--guppy_group", help='Name of the guppy basecall group if move table is in FAST5 file.',type=str, default='Basecall_1D_000')
60 | detect_parser.add_argument("--chrom", nargs='*', help='A space/whitespace separated list of contigs, e.g. chr3 chr6 chr22. If not list is provided then all chromosomes in the reference are used.')
61 |
62 | detect_parser.add_argument("--fast5_move", help='Use move table from FAST5 file instead of BAM file. If this flag is set, specify a basecall group for FAST5 file using --guppy_group parameter and ensure that the FAST5 files contains move table.', default=False, action='store_true')
63 |
64 | detect_parser.add_argument("--skip_per_site", help='Skip per site output', default=False, action='store_true')
65 | detect_parser.add_argument("--device", help='Device to use for running pytorch models. you can set --device=cpu for cpu, or --device=cuda for GPU. You can also specify a particular GPU device such as --device=cuda:0 or --device=cuda:1 . If --device paramater is not set by user, then GPU will be used if available otherwise CPU will be used.', type=str)
66 | detect_parser.add_argument("--disable_pruning", help='Disable model pruning (not recommended for CPU inference). By default models are pruned to remove some weights with low L1 norm in linear layers. Pruning has little effect on model accuracy, it can signifcantly improve CPU inference time but not GPU inference time.', default=False, action='store_true')
67 |
68 | detect_parser.add_argument("--exclude_ref_features", help='Exclude reference sequence from feature matrix. By default, if a reference FASTA file is provided via --ref parameter, then the reference sequence is added as a feature for aligned reads, but not if a read is unmapped or if no reference is provided.', default=False, action='store_true')
69 | detect_parser.add_argument("--batch_size", help='Batch size to use for GPU inference. For CPU inference, batch size is fixed at 512.',type=int, default=1024)
70 |
71 | detect_parser.add_argument("--bam_threads", help='Number of threads to use for compressed BAM output. Setting it lower than 4 can significantly lower the runtime.',type=int, default=4)
72 | detect_parser.add_argument("--skip_unmapped", help='Skip unmapped reads from methylation calling. If --chrom is used then unmapped are automatically skipped.', default=False, action='store_true')
73 |
74 | merge_parser = main_subparsers.add_parser("merge", parents=[parent_parser],
75 | add_help=True,
76 | help="Merge per-read calls into per-site calls", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
77 |
78 | merge_parser.add_argument("--input", nargs='*', help= 'List of paths of per-read methylation calls to merge. File paths should be separated by space/whitespace. Use either --input or --list argument, but not both.')
79 | merge_parser.add_argument("--list", help= 'A file containing paths to per-read methylation calls to merge (one per line). Use either --inputs or --list argument, but not both.', type=str)
80 | merge_parser.add_argument("--cpg_output", help= 'Create an additional per-site output file with forward and negative strand counts for CpG sites combined.', default=False, action='store_true')
81 |
82 | if len(sys.argv)==1:
83 | parser.print_help()
84 | parser.exit()
85 |
86 |
87 | elif len(sys.argv)==2:
88 | if sys.argv[1]=='merge':
89 | merge_parser.print_help()
90 | merge_parser.exit()
91 |
92 | elif sys.argv[1]=='detect':
93 | detect_parser.print_help()
94 | detect_parser.exit()
95 |
96 | args = parser.parse_args()
97 |
98 |
99 | if args.print_models:
100 | utils.get_model_help()
101 | parser.exit()
102 |
103 | t=time.time()
104 |
105 | print('%s: Starting DeepMod2.' %str(datetime.datetime.now()), flush=True)
106 |
107 | if not args.output:
108 | args.output=os.getcwd()
109 |
110 | os.makedirs(args.output, exist_ok=True)
111 |
112 | if args.option=='merge':
113 | if args.input:
114 | input_list= args.input
115 |
116 | elif args.list:
117 | with open(args.list,'r') as file_list:
118 | input_list=[x.rstrip('\n') for x in file_list.readlines()]
119 |
120 | params={'output':args.output, 'prefix':args.prefix, 'qscore_cutoff':args.qscore_cutoff,
121 | 'length_cutoff':args.length_cutoff, 'mod_t':args.mod_t,
122 | 'unmod_t':args.unmod_t,'include_non_cpg_ref':args.include_non_cpg_ref, 'cpg_output':args.cpg_output}
123 |
124 | site_pred_file=utils.get_per_site(params, input_list)
125 |
126 | else:
127 | if args.chrom:
128 | args.skip_unmapped=True
129 | chrom_list=args.chrom
130 | else:
131 | chrom_list=pysam.AlignmentFile(args.bam,'rb',check_sq=False).references
132 |
133 | if args.device:
134 | dev=args.device
135 | else:
136 | if torch.cuda.is_available():
137 | dev = "cuda"
138 | else:
139 | dev = "cpu"
140 |
141 |
142 | motif_seq, exp_motif_seq, motif_ind, valid_motif=utils.motif_check(args.motif)
143 | if not valid_motif:
144 | sys.exit(3)
145 |
146 | params={'input':args.input, 'output':args.output, 'threads':args.threads,
147 | 'prefix':args.prefix, 'model':args.model,
148 | 'qscore_cutoff':args.qscore_cutoff, 'ref':args.ref,
149 | 'length_cutoff':args.length_cutoff, 'bam':args.bam,
150 | 'file_type':args.file_type, 'fast5_move':args.fast5_move,
151 | 'guppy_group':args.guppy_group,
152 | 'mod_t':args.mod_t, 'unmod_t':args.unmod_t, 'include_non_cpg_ref': args.include_non_cpg_ref,
153 | 'skip_per_site':args.skip_per_site, 'chrom_list':chrom_list, "dev":dev,
154 | 'disable_pruning':args.disable_pruning, 'batch_size':args.batch_size,
155 | 'exclude_ref_features':args.exclude_ref_features,'bam_threads':args.bam_threads,
156 | 'skip_unmapped':args.skip_unmapped, 'mod_positions':args.mod_positions,
157 | 'motif_seq':motif_seq, 'motif_ind':motif_ind,'exp_motif_seq':exp_motif_seq,
158 | 'reference_motif_only':args.reference_motif_only,
159 | 'seq_type':args.seq_type,
160 | 'mod_symbol':args.mod_symbol
161 | }
162 |
163 | print('\n%s: \nCommand: python %s\n' %(str(datetime.datetime.now()), ' '.join(sys.argv)), flush=True)
164 |
165 | with open(os.path.join(args.output,'args'),'w') as file:
166 | file.write('Command: python %s\n\n\n' %(' '.join(sys.argv)))
167 | file.write('------Parameters Used For Running DeepMod2------\n')
168 | for k in vars(args):
169 | file.write('{}: {}\n'.format(k,vars(args)[k]) )
170 |
171 | from src import detect
172 | detect.call_manager(params)
173 |
174 | print('\n%s: Time elapsed=%.4fs' %(str(datetime.datetime.now()),time.time()-t), flush=True)
175 |
--------------------------------------------------------------------------------
/docs/Example.md:
--------------------------------------------------------------------------------
1 | # DeepMod2 Run Example
2 |
3 | This example shows how to use DeepMod2 to predict 5mC methylation from POD5 and FAST5 files. We will use a sample dataset from https://github.com/WGLab/DeepMod2/releases/tag/v0.3.0.
4 |
5 | - [1. Methylation Calling from POD5 files with Dorado basecalling](Example.md#1-methylation-calling-from-pod5-files-with-dorado-basecalling)
6 | - [1.1 Reference Anchored Methylation Calling](Example.md#11-reference-anchored-methylation-calling)
7 | - [1.1.1 Dorado Basecalling and Read Alignment to Reference Genome](Example.md#111-dorado-basecalling-and-read-alignment-to-reference-genome)
8 | - [1.1.2 (Optional) Read Phasing for diploid genomes](Example.md#112-optional-read-phasing-for-diploid-genomes)
9 | - [1.1.3 Methylation Calling with DeepMod2](Example.md#113-methylation-calling-with-deepmod2)
10 | - [1.1.4 Visualizing DeepMod2 Methylation in IGV](Example.md#114-visualizing-deepmod2-methylation-in-igv)
11 | - [1.2 Reference free methylation calling](Example.md#12-reference-free-methylation-calling)
12 | - [1.2.1 Dorado Basecalling ](Example.md#121-dorado-basecalling)
13 | - [1.2.2 Reference Free Methylation Calling with DeepMod2](Example.md#122-reference-free-methylation-calling-with-deepmod2)
14 | - [1.2.3 Optional Read Alignment to Reference Genome and Per-site frequency calculation with modkit](Example.md#123-optional-read-alignment-to-reference-genome-and-per-site-frequency-calculation-with-modkit)
15 | - [2. Methylation Calling from FAST5 files with Guppy basecalling](Example.md#2-methylation-calling-from-fast5-files-with-guppy-basecalling)
16 | - [2.1 Reference Anchored Methylation Calling](Example.md#21-reference-anchored-methylation-calling)
17 | - [2.1.1 Guppy Basecalling and Read Alignment to Reference Genome](Example.md#211-guppy-basecalling-and-read-alignment-to-reference-genome)
18 | - [2.1.2 Methylation Calling with DeepMod2 and Guppy Basecalling](Example.md#212-methylation-calling-with-deepmod2-and-guppy-basecalling)
19 | - [2.2 Reference free methylation calling with Guppy Basecalling](Example.md#22-reference-free-methylation-calling-with-guppy-basecalling)
20 |
21 | # 1. Methylation Calling from POD5 files with Dorado basecalling
22 | **Prepare Directories**
23 | ```
24 | INPUT_DIR=data
25 | OUTPUT_DIR=mod
26 |
27 | mkdir -p ${INPUT_DIR}
28 | mkdir -p ${OUTPUT_DIR}
29 | ```
30 | **Download Software Packages**
31 | ```
32 | # Install DeepMod2
33 | git clone https://github.com/WGLab/DeepMod2.git ${INPUT_DIR}/DeepMod2
34 | conda env create -f ${INPUT_DIR}/DeepMod2/environment.yml
35 | conda activate deepmod2
36 | conda install samtools minimap2 bedtools -y
37 |
38 | # Download Dorado Basecaller and model
39 | wget -qO- https://cdn.oxfordnanoportal.com/software/analysis/dorado-0.5.3-linux-x64.tar.gz | tar xzf - -C ${INPUT_DIR}
40 | ${INPUT_DIR}/dorado-0.5.3-linux-x64/bin/dorado download --model dna_r10.4.1_e8.2_400bps_hac@v4.3.0 --directory ${INPUT_DIR}/dorado-0.5.3-linux-x64/models/
41 | ```
42 |
43 | **Download Nanopore data and reference genome**
44 | ```
45 | # Download reference genome
46 | wget ftp://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/000/001/405/GCF_000001405.39_GRCh38.p13/GRCh38_major_release_seqs_for_alignment_pipelines/GCA_000001405.15_GRCh38_no_alt_plus_hs38d1_analysis_set.fna.gz -O -| gunzip -c > ${INPUT_DIR}/GRCh38.fa
47 | wget ftp://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/000/001/405/GCF_000001405.39_GRCh38.p13/GRCh38_major_release_seqs_for_alignment_pipelines/GCA_000001405.15_GRCh38_no_alt_plus_hs38d1_analysis_set.fna.fai -O ${INPUT_DIR}/GRCh38.fa.fai
48 |
49 | # Download POD5 files
50 | mkdir -p ${INPUT_DIR}/nanopore_raw_data
51 |
52 | wget -qO- https://github.com/WGLab/DeepMod2/files/14368872/sample.pod5.tar.gz| tar xzf - -C ${INPUT_DIR}/nanopore_raw_data
53 | ```
54 |
55 | ## 1.1 Reference Anchored Methylation Calling
56 | In order to perform reference-anchored methylation calling, we will provide an aligned BAM to DeepMod2. In this case, DeepMod2 will detect 5mC in all CpG motifs found on the read, as well as at any bases of the read that map to a reference CpG site. Finally, it will combine per-read predictions for a given reference CpG site into a per-site methylation frequency. Any unaligned reads or unaligned segments of reads will also be analyzed for 5mC and reported in the per-read output and BAM file, but will not be used in per-site frequency calculation.
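Conceptually, the per-site frequency is computed from per-read probabilities using the `--mod_t` and `--unmod_t` thresholds (both default to 0.5): predictions with probability >= `mod_t` count as modified, predictions with probability < `unmod_t` count as unmodified, and anything in between is excluded. The following is a minimal, illustrative sketch of this aggregation, not DeepMod2's actual implementation:

```
# Illustrative sketch: aggregate per-read 5mC probabilities at one reference CpG site
def site_frequency(probs, mod_t=0.5, unmod_t=0.5):
    mod = sum(p >= mod_t for p in probs)     # reads called modified
    unmod = sum(p < unmod_t for p in probs)  # reads called unmodified
    coverage = mod + unmod                   # ambiguous reads are excluded
    return mod / coverage if coverage else None

print(site_frequency([0.98, 0.99, 0.02, 0.85]))  # 0.75
```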
57 |
58 | ### 1.1.1 Dorado Basecalling and Read Alignment to Reference Genome
59 | First, we will basecall our Nanopore signal files using the Dorado basecaller. It is possible to align the reads during basecalling or after basecalling; both options are shown below. Since we need the move table for our basecalled DNA sequences, we will use `--emit-moves` when running Dorado, which will produce an aligned (Option A) or unaligned (Option B) BAM file.
60 |
61 | #### Option A: Perform Read Alignment during Basecalling with Dorado
62 |
63 | Dorado has the option to perform read alignment using minimap2 during basecalling if a reference FASTA file is provided via the `--reference` option. This can be helpful in reducing the number of steps needed.
64 |
65 | ```
66 | ${INPUT_DIR}/dorado-0.5.3-linux-x64/bin/dorado basecaller --emit-moves --recursive --reference ${INPUT_DIR}/GRCh38.fa ${INPUT_DIR}/dorado-0.5.3-linux-x64/models/dna_r10.4.1_e8.2_400bps_hac@v4.3.0 ${INPUT_DIR}/nanopore_raw_data > ${OUTPUT_DIR}/aligned.bam
67 | ```
68 |
69 | This will produce an aligned BAM file named `aligned.bam` under the `$OUTPUT_DIR` folder.
70 |
71 | #### Option B: Perform Read Alignment after Basecalling with Dorado
72 |
73 | It is possible to run the Dorado basecaller without performing alignment. This can be helpful in speeding up the basecalling process, which requires a potentially expensive GPU instance. It also gives you more flexibility in how you perform the alignment, e.g. with specific minimap2 parameters.
74 |
75 | **Basecalling with Dorado**
76 |
77 | ```
78 | # Perform basecalling
79 | ${INPUT_DIR}/dorado-0.5.3-linux-x64/bin/dorado basecaller --emit-moves --recursive ${INPUT_DIR}/dorado-0.5.3-linux-x64/models/dna_r10.4.1_e8.2_400bps_hac@v4.3.0 ${INPUT_DIR}/nanopore_raw_data > ${OUTPUT_DIR}/basecalled.bam
80 | ```
81 |
82 | This will produce an unaligned BAM file named `basecalled.bam` under the `$OUTPUT_DIR` folder.
83 |
84 | **Alignment with minimap2**
85 |
86 | We will convert this BAM file into FASTQ format while keeping all the tags, and pipe it into minimap2 for alignment.
87 |
88 | ```
89 | # Align using minimap2 while copying move table information
90 | samtools fastq ${OUTPUT_DIR}/basecalled.bam -T "*"|minimap2 -ax map-ont ${INPUT_DIR}/GRCh38.fa - -y|samtools view -o ${OUTPUT_DIR}/aligned.bam
91 | ```
92 |
93 | This will produce an aligned BAM file named `aligned.bam` under the `$OUTPUT_DIR` folder.
94 |
95 | ### 1.1.2 (Optional) Read Phasing for diploid genomes
96 | You can optionally use a SNP calling and haplotyping tool such as NanoCaller or Clair3 to phase the BAM file into parental haplotypes. The phased BAM file can be provided as input to DeepMod2 instead of `${OUTPUT_DIR}/aligned.bam` to get haplotype-specific methylation calls.
97 |
98 | ```
99 | #install NanoCaller
100 | conda install -c bioconda NanoCaller
101 |
102 | #sort and index the BAM file
103 | samtools sort ${OUTPUT_DIR}/aligned.bam -o ${OUTPUT_DIR}/aligned.sorted.bam
104 | samtools index ${OUTPUT_DIR}/aligned.sorted.bam
105 |
106 | #Run NanoCaller to phase the reads
107 | NanoCaller --bam ${OUTPUT_DIR}/aligned.sorted.bam --ref ${INPUT_DIR}/GRCh38.fa --mode snps --phase --output ${OUTPUT_DIR}/nanocaller --wgs_contigs chr1-22XY --cpu 8
108 |
109 | # Merge phased reads into a single BAM file
110 | find ${OUTPUT_DIR}/nanocaller/intermediate_phase_files -type f -name '*bam'|samtools cat -b - -o ${OUTPUT_DIR}/phased.bam
111 |
112 | ```
113 |
114 | ### 1.1.3 Methylation Calling with DeepMod2
115 | Now we will run DeepMod2's `detect` module using the `bilstm_r10.4.1_5khz_v4.3` model, with the aligned BAM file and Nanopore signal files as input. Since we want to perform reference-anchored methylation calling, we will also provide the reference genome FASTA file as input. We will use the phased BAM file from the previous step, but you can also use the `${OUTPUT_DIR}/aligned.bam` file if you do not want haplotype-specific methylation calls.
116 |
117 | ```
118 | # Run DeepMod2
119 | BAM_INPUT=${OUTPUT_DIR}/phased.bam # Use ${OUTPUT_DIR}/aligned.bam if you did not use NanoCaller to phase the reads
120 | python ${INPUT_DIR}/DeepMod2/deepmod2 detect --model bilstm_r10.4.1_5khz_v4.3 --file_type pod5 --bam $BAM_INPUT --input ${INPUT_DIR}/nanopore_raw_data --output ${OUTPUT_DIR}/deepmod2/ --ref ${INPUT_DIR}/GRCh38.fa --threads 8
121 | ```
122 | The output folder of DeepMod2 `${OUTPUT_DIR}/deepmod2/` will contain the following files:
123 | ```
124 | args -> Shows the arguments and command used to run DeepMod2
125 | output.bam -> Unsorted methylation-tagged BAM file
126 | output.per_read -> Per-read methylation calls in sorted BED format
127 | output.per_site -> Per-site methylation calls for +/- strands separately, in sorted BED format
128 | output.per_site.aggregated -> Per-site methylation calls with counts for +/- strands combined
129 | ```
130 |
131 | **Per-read Output**
132 | We will inspect the contents of the per-read output file using `head ${OUTPUT_DIR}/deepmod2/output.per_read`:
133 |
134 |
135 | |read_name|chromosome|ref_position_before|ref_position|read_position|strand|methylation_score|mean_read_qscore|read_length|read_phase|ref_cpg|
136 | |-|-|-|-|-|-|-|-|-|-|-|
137 | |160f871b-f4c3-40de-a160-383fcd5033e7|chr11|2733569|2733570|35|-|0.0036|18.05|48187|1|TRUE|
138 | |160f871b-f4c3-40de-a160-383fcd5033e7|chr11|2733457|2733458|146|-|0.9846|18.05|48187|1|TRUE|
139 | |160f871b-f4c3-40de-a160-383fcd5033e7|chr11|2733439|2733440|164|-|0.0048|18.05|48187|1|TRUE|
140 | |160f871b-f4c3-40de-a160-383fcd5033e7|chr11|2733362|2733363|242|-|0.2352|18.05|48187|1|TRUE|
141 | |160f871b-f4c3-40de-a160-383fcd5033e7|chr11|2733356|2733357|248|-|0.841|18.05|48187|1|TRUE|
142 | |160f871b-f4c3-40de-a160-383fcd5033e7|chr11|2733351|2733352|253|-|0.9893|18.05|48187|1|FALSE|
143 | |160f871b-f4c3-40de-a160-383fcd5033e7|chr11|2733341|2733342|263|-|0.012|18.05|48187|1|TRUE|
144 | |160f871b-f4c3-40de-a160-383fcd5033e7|chr11|2733155|2733156|449|-|0.9858|18.05|48187|1|TRUE|
145 | |160f871b-f4c3-40de-a160-383fcd5033e7|chr11|2733143|2733144|461|-|0.966|18.05|48187|1|TRUE|
146 |
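Since the per-read file is a plain tab-separated table with the header shown above, it can also be explored programmatically. A small, illustrative sketch using pandas (not a DeepMod2 dependency, so install it separately if needed):

```
# Illustrative sketch: summarize per-read calls by haplotype with pandas
import pandas as pd

df = pd.read_csv("mod/deepmod2/output.per_read", sep="\t")
cpg = df[df["ref_cpg"]]  # keep calls at reference CpG sites only
# Fraction of per-read calls with methylation score >= 0.5, split by read phase
print(cpg.groupby("read_phase")["methylation_score"].apply(lambda s: (s >= 0.5).mean()))
```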
147 |
148 | **Per-Site Output**
149 |
150 | We will use `bedtools intersect` to inspect per-site methylation frequencies in the chr11:2699000-2702000 imprinting control region, using the per-site output file with stranded CpG counts:
151 |
152 | ```
153 | printf 'chr11\t2699000\t2702000'|bedtools intersect -header -a ${OUTPUT_DIR}/deepmod2/output.per_site -b -|head
154 | ```
155 |
156 | |#chromosome|position_before|position|strand|ref_cpg|coverage|mod_coverage|unmod_coverage|mod_percentage|coverage_phase1|mod_coverage_phase1|unmod_coverage_phase1|mod_percentage_phase1|coverage_phase2|mod_coverage_phase2|unmod_coverage_phase2|mod_percentage_phase2|
157 | |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|
158 | |chr11|2699031|2699032|+|TRUE|16|8|8|0.5|8|0|8|0|8|8|0|1|
159 | |chr11|2699032|2699033|-|TRUE|11|7|4|0.6364|4|0|4|0|7|7|0|1|
160 | |chr11|2699037|2699038|+|TRUE|16|8|8|0.5|8|0|8|0|8|8|0|1|
161 | |chr11|2699038|2699039|-|TRUE|11|7|4|0.6364|4|0|4|0|7|7|0|1|
162 | |chr11|2699048|2699049|+|TRUE|16|6|10|0.375|8|0|8|0|8|6|2|0.75|
163 | |chr11|2699049|2699050|-|TRUE|11|5|6|0.4545|4|0|4|0|7|5|2|0.7143|
164 | |chr11|2699099|2699100|+|TRUE|15|7|8|0.4667|8|0|8|0|7|7|0|1|
165 | |chr11|2699100|2699101|-|TRUE|11|7|4|0.6364|4|0|4|0|7|7|0|1|
166 | |chr11|2699101|2699102|+|TRUE|16|7|9|0.4375|8|0|8|0|8|7|1|0.875|
167 |
168 |
169 | **Aggregated Per-Site Output**
170 |
171 | We will use `bedtools intersect` to inspect per-site methylation frequencies in the chr11:2699000-2702000 imprinting control region, using the per-site output file with CpG counts aggregated over +/- strands:
172 |
173 | ```
174 | printf 'chr11\t2699000\t2702000'|bedtools intersect -header -a ${OUTPUT_DIR}/deepmod2/output.per_site.aggregated -b -|head
175 | ```
176 |
177 | |#chromosome|position_before|position|ref_cpg|coverage|mod_coverage|unmod_coverage|mod_percentage|coverage_phase1|mod_coverage_phase1|unmod_coverage_phase1|mod_percentage_phase1|coverage_phase2|mod_coverage_phase2|unmod_coverage_phase2|mod_percentage_phase2|
178 | |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|
179 | |chr11|2699031|2699033|TRUE|27|15|12|0.5556|12|0|12|0|15|15|0|1|
180 | |chr11|2699037|2699039|TRUE|27|15|12|0.5556|12|0|12|0|15|15|0|1|
181 | |chr11|2699048|2699050|TRUE|27|11|16|0.4074|12|0|12|0|15|11|4|0.7333|
182 | |chr11|2699099|2699101|TRUE|26|14|12|0.5385|12|0|12|0|14|14|0|1|
183 | |chr11|2699101|2699103|TRUE|27|14|13|0.5185|12|0|12|0|15|14|1|0.9333|
184 | |chr11|2699115|2699117|TRUE|27|14|13|0.5185|12|0|12|0|15|14|1|0.9333|
185 | |chr11|2699120|2699122|TRUE|27|15|12|0.5556|12|0|12|0|15|15|0|1|
186 | |chr11|2699180|2699182|TRUE|27|15|12|0.5556|12|0|12|0|15|15|0|1|
187 | |chr11|2699187|2699189|TRUE|27|15|12|0.5556|12|0|12|0|15|15|0|1|
188 |
189 | These results show that phase 1 is completely unmodified (column mod_percentage_phase1) whereas phase 2 is nearly completely modified (mod_percentage_phase2), which is what we expect for this imprinted region.
190 |
191 |
192 | ### 1.1.4 Visualizing DeepMod2 Methylation in IGV
193 | Since the methylation tagged BAM file produced by DeepMod2 is not sorted, we will first sort and index it:
194 |
195 | ```
196 | samtools sort ${OUTPUT_DIR}/deepmod2/output.bam -o ${OUTPUT_DIR}/deepmod2/output.sorted.bam --write-index
197 | ```
198 |
199 | Open the BAM file `${OUTPUT_DIR}/deepmod2/output.sorted.bam` in IGV and select `Color alignments by base modification (5mC)`. If you used a phased BAM file for methylation calling, you can select `Group alignments by phase` to separate reads by haplotype. Go to the region `chr11:2699000-2702000` to see the following methylation tags:
200 |
201 | 
202 |
203 |
204 | ## 1.2 Reference free methylation calling
205 | In order to perform reference-free methylation calling, we will provide an unaligned BAM to DeepMod2. In this case, DeepMod2 will detect 5mC only in CpG motifs found on the read and will not use the reference sequence as a feature. Per-read predictions will not be combined into a per-site methylation frequency, and methylation will be reported only in the per-read output and BAM file.
206 |
207 | ### 1.2.1 Dorado Basecalling
208 |
209 | ```
210 | # Perform basecalling
211 | ${INPUT_DIR}/dorado-0.5.3-linux-x64/bin/dorado basecaller --emit-moves --recursive ${INPUT_DIR}/dorado-0.5.3-linux-x64/models/dna_r10.4.1_e8.2_400bps_hac@v4.3.0 ${INPUT_DIR}/nanopore_raw_data > ${OUTPUT_DIR}/basecalled.bam
212 | ```
213 |
214 | This will produce an unaligned BAM file named `basecalled.bam` under the `$OUTPUT_DIR` folder.
215 |
216 | ### 1.2.2 Reference Free Methylation Calling with DeepMod2
217 |
218 | Now we will run DeepMod2's `detect` module using the `bilstm_r10.4.1_5khz_v4.3` model, with the unaligned BAM file and Nanopore signal files as input. In this situation, we will not provide a reference genome FASTA file as input.
219 |
220 | ```
221 | # Run DeepMod2
222 | BAM_INPUT=${OUTPUT_DIR}/basecalled.bam
223 | python ${INPUT_DIR}/DeepMod2/deepmod2 detect --model bilstm_r10.4.1_5khz_v4.3 --file_type pod5 --bam $BAM_INPUT --input ${INPUT_DIR}/nanopore_raw_data --output ${OUTPUT_DIR}/deepmod2/ --threads 8
224 | ```
225 |
226 | The output folder of DeepMod2 `${OUTPUT_DIR}/deepmod2/` will contain the following files:
227 | ```
228 | args -> Shows the arguments and command used to run DeepMod2
229 | output.bam -> Unsorted methylation-tagged BAM file
230 | output.per_read -> Per-read methylation calls in sorted BED format
231 | output.per_site -> Per-site methylation file (empty in reference-free mode)
232 | output.per_site.aggregated -> Aggregated per-site methylation file (empty in reference-free mode)
233 | ```
234 |
235 | **Per-read Output**
236 | We will inspect the contents of the per-read output file using `head ${OUTPUT_DIR}/deepmod2/output.per_read`:
237 |
238 |
239 | |read_name|chromosome|ref_position_before|ref_position|read_position|strand|methylation_score|mean_read_qscore|read_length|read_phase|ref_cpg|
240 | |-|-|-|-|-|-|-|-|-|-|-|
241 | |160f871b-f4c3-40de-a160-383fcd5033e7|NA|NA|NA|35|+|0.0042|18.12|48175|0|FALSE|
242 | |160f871b-f4c3-40de-a160-383fcd5033e7|NA|NA|NA|146|+|0.985|18.12|48175|0|FALSE|
243 | |160f871b-f4c3-40de-a160-383fcd5033e7|NA|NA|NA|164|+|0.0057|18.12|48175|0|FALSE|
244 | |160f871b-f4c3-40de-a160-383fcd5033e7|NA|NA|NA|242|+|0.36|18.12|48175|0|FALSE|
245 | |160f871b-f4c3-40de-a160-383fcd5033e7|NA|NA|NA|248|+|0.9428|18.12|48175|0|FALSE|
246 | |160f871b-f4c3-40de-a160-383fcd5033e7|NA|NA|NA|253|+|0.9935|18.12|48175|0|FALSE|
247 | |160f871b-f4c3-40de-a160-383fcd5033e7|NA|NA|NA|263|+|0.0123|18.12|48175|0|FALSE|
248 | |160f871b-f4c3-40de-a160-383fcd5033e7|NA|NA|NA|449|+|0.9866|18.12|48175|0|FALSE|
249 | |160f871b-f4c3-40de-a160-383fcd5033e7|NA|NA|NA|461|+|0.968|18.12|48175|0|FALSE|
250 |
251 |
252 | In this case the chromosome, position and strand information is not available since the reads were unaligned.
253 |
254 | ### 1.2.3 Optional Read Alignment to Reference Genome and Per-site frequency calculation with modkit
255 | The unaligned BAM file produced by DeepMod2 contains 5mC tags MM and ML for all CpG motifs in a read. Reads from this BAM file can be aligned to any reference genome and methylation counts for reads mapping to the same reference CpG sites can be obtained using [modkit](https://github.com/nanoporetech/modkit).
256 |
257 | ```
258 | # Align reads using minimap2 and then sort and index the BAM file
259 | samtools fastq ${OUTPUT_DIR}/deepmod2/output.bam -T "*"|minimap2 -ax map-ont ${INPUT_DIR}/GRCh38.fa - -y|samtools sort -o ${OUTPUT_DIR}/deepmod2/aligned.sorted.bam --write-index
260 |
261 | modkit pileup ${OUTPUT_DIR}/deepmod2/aligned.sorted.bam ${OUTPUT_DIR}/deepmod2/per_site_methylation.bed --ref ${INPUT_DIR}/GRCh38.fa --preset traditional
262 | ```
263 |
264 | modkit's per-site frequency calculations will be stored in the `${OUTPUT_DIR}/deepmod2/per_site_methylation.bed` file, in the column format described [here](https://github.com/nanoporetech/modkit?tab=readme-ov-file#bedmethyl-column-descriptions). We will use `bedtools intersect` to inspect per-site methylation frequencies in the chr11:2699000-2702000 imprinting control region, from the per-site output file with CpG counts aggregated over +/- strands:
265 |
266 | ```
267 | printf 'chr11\t2699900\t2702000'|bedtools intersect -header -a ${OUTPUT_DIR}/deepmod2/per_site_methylation.bed -b -|head
268 | ```
269 |
270 | |chrom|start|end|code|score|strand|start|end|color|N_valid_cov|fraction_modified|N_mod|N_canonical|N_other_mod|N_delete|N_fail|N_diff|N_nocall|
271 | |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|
272 | |chr11|2699906|2699907|m|25|.|2699906|2699907|255,0,0|25|60|15|10|0|0|0|1|2|
273 | |chr11|2699915|2699916|m|27|.|2699915|2699916|255,0,0|27|55.56|15|12|0|0|0|0|1|
274 | |chr11|2699917|2699918|m|28|.|2699917|2699918|255,0,0|28|57.14|16|12|0|0|0|0|0|
275 | |chr11|2699927|2699928|m|27|.|2699927|2699928|255,0,0|27|55.56|15|12|0|0|0|0|1|
276 | |chr11|2699966|2699967|m|28|.|2699966|2699967|255,0,0|28|57.14|16|12|0|0|0|0|0|
277 | |chr11|2699976|2699977|m|27|.|2699976|2699977|255,0,0|27|55.56|15|12|0|0|1|0|0|
278 | |chr11|2699979|2699980|m|28|.|2699979|2699980|255,0,0|28|57.14|16|12|0|0|0|0|0|
279 | |chr11|2699981|2699982|m|28|.|2699981|2699982|255,0,0|28|57.14|16|12|0|0|0|0|0|
280 | |chr11|2699984|2699985|m|27|.|2699984|2699985|255,0,0|27|55.56|15|12|0|0|1|0|0|
281 | |chr11|2699988|2699989|m|27|.|2699988|2699989|255,0,0|27|55.56|15|12|0|0|1|0|0|
282 |
283 | Even though this result does not show haplotype specific methylation, we can see that the modified fraction is ~50-60% as expected.
284 |
285 |
286 | # 2. Methylation Calling from FAST5 files with Guppy basecalling
287 | **Prepare Directories**
288 | ```
289 | INPUT_DIR=data
290 | OUTPUT_DIR=mod
291 |
292 | mkdir -p ${INPUT_DIR}
293 | mkdir -p ${OUTPUT_DIR}
294 | ```
295 | **Download Software Packages**
296 | ```
297 | # Install DeepMod2
298 | git clone https://github.com/WGLab/DeepMod2.git ${INPUT_DIR}/DeepMod2
299 | conda env create -f ${INPUT_DIR}/DeepMod2/environment.yml
300 | conda activate deepmod2
301 | conda install samtools minimap2 bedtools -y
302 |
303 | # Download Guppy basecaller
304 | wget -qO- https://cdn.oxfordnanoportal.com/software/analysis/ont-guppy_6.5.7_linux64.tar.gz| tar xzf - -C ${INPUT_DIR}
305 | ```
306 |
307 | **Download Nanopore data and reference genome**
308 | ```
309 | # Download reference genome
310 | wget ftp://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/000/001/405/GCF_000001405.39_GRCh38.p13/GRCh38_major_release_seqs_for_alignment_pipelines/GCA_000001405.15_GRCh38_no_alt_plus_hs38d1_analysis_set.fna.gz -O -| gunzip -c > ${INPUT_DIR}/GRCh38.fa
311 | wget ftp://ftp.ncbi.nlm.nih.gov/genomes/all/GCF/000/001/405/GCF_000001405.39_GRCh38.p13/GRCh38_major_release_seqs_for_alignment_pipelines/GCA_000001405.15_GRCh38_no_alt_plus_hs38d1_analysis_set.fna.fai -O ${INPUT_DIR}/GRCh38.fa.fai
312 |
313 | # Download FAST5 files
314 | mkdir -p ${INPUT_DIR}/nanopore_raw_data
315 |
316 | wget -qO- https://github.com/WGLab/DeepMod2/files/14368873/sample.fast5.tar.gz| tar xzf - -C ${INPUT_DIR}/nanopore_raw_data
317 | ```
318 |
319 | ## 2.1 Reference Anchored Methylation Calling
320 | In order to perform reference-anchored methylation calling, we will provide an aligned BAM to DeepMod2. In this case, DeepMod2 will detect 5mC in all CpG motifs found on the read, as well as at any bases of the read that map to a reference CpG site. Finally, it will combine per-read predictions for a given reference CpG site into a per-site methylation frequency. Any unaligned reads or unaligned segments of reads will also be analyzed for 5mC and reported in the per-read output and BAM file, but will not be used in per-site frequency calculation.
321 |
322 | ### 2.1.1 Guppy Basecalling and Read Alignment to Reference Genome
323 | First, we will basecall our Nanopore signal files using the Guppy basecaller. It is possible to align the reads during basecalling or after basecalling; both options are shown below. Since we need the move table for our basecalled DNA sequences, we will use `--moves_out` when running Guppy, which will produce an aligned (Option A) or unaligned (Option B) BAM file.
324 |
325 | #### Option A: Perform Read Alignment during Basecalling with Guppy
326 |
327 | Guppy has the option to perform read alignment using minimap2 during basecalling if a reference FASTA file is provided via the `--align_ref` option. This can be helpful in reducing the number of steps needed.
328 |
329 | ```
330 | ${INPUT_DIR}/ont-guppy/bin/guppy_basecaller -i ${INPUT_DIR}/nanopore_raw_data -s ${OUTPUT_DIR}/basecalls --align_ref ${INPUT_DIR}/GRCh38.fa -c dna_r10.4.1_e8.2_400bps_5khz_hac.cfg --bam_out --moves_out --recursive
331 | ```
332 |
333 | This will produce several aligned BAM files under the `${OUTPUT_DIR}/basecalls` subfolders `pass` and `fail`. We will combine these BAM files into a single BAM file to use as input to DeepMod2; we have a choice of using both pass and fail reads or just pass reads.
334 |
335 | ```
336 | find ${OUTPUT_DIR}/basecalls \( -path "*/pass/*" -o -path "*/fail/*" \) -type f -name "*.bam"|samtools cat -b - -o ${OUTPUT_DIR}/aligned.bam
337 | ```
338 |
339 | #### Option B: Perform Read Alignment after Basecalling with Guppy
340 |
341 | It is possible to run the Guppy basecaller without performing alignment. This can be helpful in speeding up the basecalling process, which requires a potentially expensive GPU instance. It also gives you more flexibility in how you perform the alignment, e.g. with specific minimap2 parameters.
342 |
343 | **Basecalling with Guppy**
344 |
345 | ```
346 | # Perform basecalling
347 | ${INPUT_DIR}/ont-guppy/bin/guppy_basecaller -i ${INPUT_DIR}/nanopore_raw_data -s ${OUTPUT_DIR}/basecalls -c dna_r10.4.1_e8.2_400bps_5khz_hac.cfg --bam_out --moves_out --recursive
348 | ```
349 |
350 | This will produce several unaligned BAM files under the `${OUTPUT_DIR}/basecalls` subfolders `pass` and `fail`. We will align the read sequences from these BAM files; we have a choice of using both pass and fail reads or just pass reads.
351 |
352 | **Alignment with minimap2**
353 |
354 | We will convert these BAM files into FASTQ format while keeping all the tags, and pipe it into minimap2 for alignment.
355 |
356 | ```
357 | # Align using minimap2 while copying move table information
358 |
359 | find ${OUTPUT_DIR}/basecalls \( -path "*/pass/*" -o -path "*/fail/*" \) -type f -name "*.bam"|samtools cat -b -|samtools fastq - -T "*"|minimap2 -ax map-ont ${INPUT_DIR}/GRCh38.fa - -y|samtools view -o ${OUTPUT_DIR}/aligned.bam
360 | ```
361 |
362 | This will produce an aligned BAM file named `aligned.bam` under the `$OUTPUT_DIR` folder.
363 |
364 | ---------
365 | ### 2.1.2 Methylation Calling with DeepMod2 and Guppy Basecalling
366 | Once you have an aligned BAM file, you can continue with DeepMod2 methylation calling as described in [1.1.3 Methylation Calling with DeepMod2](Example.md#113-methylation-calling-with-deepmod2); you only need to change the `--file_type` parameter to `fast5` since we are using the FAST5 file format.
367 |
368 | ## 2.2 Reference free methylation calling with Guppy Basecalling
369 | You can perform Guppy basecalling without reference alignment:
370 |
371 | ```
372 | # Perform basecalling
373 | ${INPUT_DIR}/ont-guppy/bin/guppy_basecaller -i ${INPUT_DIR}/nanopore_raw_data -s ${OUTPUT_DIR}/basecalls -c dna_r10.4.1_e8.2_400bps_5khz_hac.cfg --bam_out --moves_out --recursive
374 | ```
375 |
376 | and continue with methylation calling as described in [1.2.2 Reference Free Methylation Calling with DeepMod2](Example.md#122-reference-free-methylation-calling-with-deepmod2).
377 |
--------------------------------------------------------------------------------
/docs/Install.md:
--------------------------------------------------------------------------------
1 | ## Conda Installation
2 | First, install Miniconda, a minimal installation of Anaconda that is much smaller and faster to install:
3 | ```
4 | curl -O https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
5 | bash Miniconda3-latest-Linux-x86_64.sh
6 | ```
7 | Go through all the prompts (installation in `$HOME` is recommended). After Miniconda is installed successfully, simply run:
8 | ```
9 | git clone https://github.com/WGLab/DeepMod2.git
10 | conda env create -f DeepMod2/environment.yml
11 | conda activate deepmod2
12 | ```
13 |
14 | After installing, run `python DeepMod2/deepmod2 --help` to see the run options.
15 |
16 | ### Installation for GPU
17 | If you want to use a GPU to accelerate DeepMod2, make sure to install a CUDA-enabled version of PyTorch. Details for GPU-accelerated PyTorch can be found at https://pytorch.org/get-started/locally/, where you can select the installation command best suited to your system. One example of an installation command is shown below for CUDA 11.8:
18 |
19 | `conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia`
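After installing, you can quickly verify that PyTorch detects your GPU with a short, illustrative check:

```
# Verify that the CUDA build of PyTorch can see a GPU
import torch
print(torch.__version__, torch.cuda.is_available())  # expect True on a working CUDA setup
```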
20 |
--------------------------------------------------------------------------------
/docs/Usage.md:
--------------------------------------------------------------------------------
1 | # DeepMod2 Usage
2 |
3 |
4 | DeepMod2 has two run options, `detect` and `merge`, which can be shown using `python deepmod2 --help`:
5 | ```
6 | usage: deepmod2 [-h] [--print_models] {detect,merge} ...
7 |
8 | options:
9 | -h, --help show this help message and exit
10 | --print_models Print details of models available (default: False)
11 |
12 | Options:
13 | {detect,merge}
14 | detect Call methylation from Guppy or Dorado basecalled POD5/FAST5 files using move tables for signal alignment.
15 | merge Merge per-read calls into per-site calls
16 |
17 | ```
18 |
19 | In short, `detect` is for methylation calling from Dorado- or Guppy-generated POD5/FAST5 files, whereas `merge` is for merging per-read methylation calls from several runs into per-site methylation calls.
20 |
21 | `python deepmod2 --print_models` shows the names of available models as well as their descriptions, such as the datasets used for training. These model names can be used with the `--model` argument to specify which model to use. Please use the appropriate DeepMod2 model for the flowcell you are using, e.g. R9.4.1 vs R10.4.1.
22 |
23 | ## detect
24 | DeepMod2 can detect 5mC methylation from Dorado or Guppy basecalled POD5/FAST5 signal files using the `detect` option. In addition to signal files, you are required to provide a BAM file containing read sequences. It is recommended that the reads in the BAM file are aligned to a reference genome for more accurate results, in which case you should also provide a reference genome FASTA file. You are also required to provide move tables for the signal files, so make sure you output move tables when you run the basecaller. The move table information can be part of the BAM file as mv and ts tags, or it can be part of the FAST5 file (which is now deprecated).
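If you are unsure whether your BAM file carries the move table, you can check the tags of a read with pysam. This is an illustrative sketch, with `basecall.bam` standing in for your basecalled BAM:

```
# Illustrative sketch: check the first read of a basecalled BAM for move table tags
import pysam

with pysam.AlignmentFile("basecall.bam", "rb", check_sq=False) as bam:
    read = next(bam.fetch(until_eof=True))
    print("mv tag present:", read.has_tag("mv"))  # move table
    print("ts tag present:", read.has_tag("ts"))  # companion tag to the move table
```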
25 |
26 | ```
27 | usage: deepmod2 detect [-h] [--prefix PREFIX] [--output OUTPUT] [--qscore_cutoff QSCORE_CUTOFF] [--length_cutoff LENGTH_CUTOFF] [--mod_t MOD_T] [--unmod_t UNMOD_T] [--include_non_cpg_ref]
28 | [--threads THREADS] [--ref REF] --model MODEL --bam BAM --file_type {fast5,pod5} --input INPUT [--guppy_group GUPPY_GROUP] [--chrom [CHROM ...]] [--fast5_move]
29 | [--skip_per_site] [--device DEVICE] [--disable_pruning] [--exclude_ref_features] [--batch_size BATCH_SIZE] [--bam_threads BAM_THREADS] [--skip_unmapped]
30 |
31 | options:
32 | -h, --help show this help message and exit
33 | --prefix PREFIX Prefix for the output files (default: output)
34 | --output OUTPUT Path to folder where intermediate and final files will be stored, default is current working directory (default: None)
35 | --qscore_cutoff QSCORE_CUTOFF
36 | Minimum cutoff for mean quality score of a read (default: 0)
37 | --length_cutoff LENGTH_CUTOFF
38 | Minimum cutoff for read length (default: 0)
39 |   --mod_t MOD_T         Probability threshold for a per-read prediction to be considered modified. Only predictions with probability >= mod_t will be considered as modified for
40 |                         calculation of per-site modification levels. (default: 0.5)
41 |   --unmod_t UNMOD_T     Probability threshold for a per-read prediction to be considered unmodified. Only predictions with probability < unmod_t will be considered as unmodified for
42 |                         calculation of per-site modification levels. (default: 0.5)
43 | --include_non_cpg_ref
44 | Include non-CpG reference loci in per-site output where reads have CpG motif. (default: False)
45 | --threads THREADS Number of threads to use for processing signal and running model inference. If a GPU is used for inference, then --threads number of threads will be running on GPU
46 | concurrently. The total number of threads used by DeepMod2 is equal to --threads plus --bam_threads. It is recommended to run DeepMod2 with multiple cores, and use
47 | at least 4 bam_threads for compressing BAM file. (default: 4)
48 | --ref REF Path to reference FASTA file to anchor methylation calls to reference loci. If no reference is provided, only the motif loci on reads will be used. (default: None)
49 | --guppy_group GUPPY_GROUP
50 | Name of the guppy basecall group (default: Basecall_1D_000)
51 |   --chrom [CHROM ...]   A space/whitespace separated list of contigs, e.g. chr3 chr6 chr22. If no list is provided then all chromosomes in the reference are used. (default: None)
52 | --fast5_move Use move table from FAST5 file instead of BAM file. If this flag is set, specify a basecall group for FAST5 file using --guppy_group parameter and ensure that the
53 | FAST5 files contain move tables. (default: False)
54 | --skip_per_site Skip per site output (default: False)
55 |   --device DEVICE       Device to use for running pytorch models. You can set --device=cpu for cpu, or --device=cuda for GPU. You can also specify a particular GPU device such as
56 |                         --device=cuda:0 or --device=cuda:1. If the --device parameter is not set by the user, then GPU will be used if available, otherwise CPU will be used. (default: None)
57 |   --disable_pruning     Disable model pruning (not recommended for CPU inference). By default models are pruned to remove some weights with low L1 norm in linear layers. Pruning has
58 |                         little effect on model accuracy; it can significantly improve CPU inference time but not GPU inference time. (default: False)
59 | --exclude_ref_features
60 | Exclude reference sequence from feature matrix. By default, if a reference FASTA file is provided via --ref parameter, then the reference sequence is added as a
61 | feature for aligned reads, but not if a read is unmapped or if no reference is provided. (default: False)
62 | --batch_size BATCH_SIZE
63 | Batch size to use for GPU inference. For CPU inference, batch size is fixed at 512. (default: 1024)
64 | --bam_threads BAM_THREADS
65 | Number of threads to use for compressed BAM output. Setting it lower than 4 can significantly increase the runtime. (default: 4)
66 | --skip_unmapped Skip unmapped reads from methylation calling (default: False)
67 |
68 | Required Arguments:
69 |   --model MODEL         Name of the model. Recommended model for R9.4.1 flowcells is "bilstm_r9.4.1", for R10.4.1 flowcell (5kHz sampling) it is "bilstm_r10.4.1_5khz_v4.3", and for
70 |                         R10.4.1 flowcell (4kHz sampling) it is "bilstm_r10.4.1_4khz_v4.1". Use --print_models to display all models available. (default: None)
71 |   --bam BAM             Path to aligned or unaligned BAM file. It is ideal to have move table in BAM file but move table from FAST5 files can also be used. Aligned BAM file is required for
72 | reference anchored methylation calls, otherwise only the motif loci on reads will be called. (default: None)
73 | --file_type {fast5,pod5}
74 | Specify whether the signal is in FAST5 or POD5 file format. If POD5 file is used, then move table must be in BAM file. (default: None)
75 |   --input INPUT         Path to POD5/FAST5 file or folder containing POD5/FAST5 files. If a folder is provided, then POD5/FAST5 files will be recursively searched (default: None)
76 |
77 | ```
78 |
79 | ## merge
80 | For large datasets, it can be useful to split the data and run multiple instances of DeepMod2 for speedup. In this case, you can run `detect` with the `--skip_per_site` option and then use `deepmod2 merge` to merge the per-read calls into per-site calls.
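As a mental model of what `merge` computes, the sketch below pools per-read calls from several runs into per-site modification counts. It is illustrative only (the file paths are hypothetical, and real data should go through `deepmod2 merge`):

```
# Illustrative sketch: pool per-read calls from multiple runs into per-site counts
from collections import Counter

mod, cov = Counter(), Counter()
for path in ["run1/output.per_read", "run2/output.per_read"]:  # hypothetical paths
    with open(path) as f:
        header = f.readline().rstrip("\n").split("\t")
        for line in f:
            row = dict(zip(header, line.rstrip("\n").split("\t")))
            if row["chromosome"] == "NA":  # skip calls on unaligned reads
                continue
            site = (row["chromosome"], row["ref_position"], row["strand"])
            mod[site] += float(row["methylation_score"]) >= 0.5  # default mod_t
            cov[site] += 1

for site in list(cov)[:5]:
    print(site, mod[site] / cov[site])
```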
81 |
82 | ```
83 | usage: deepmod2 merge [-h] [--prefix PREFIX] [--output OUTPUT] [--qscore_cutoff QSCORE_CUTOFF] [--length_cutoff LENGTH_CUTOFF] [--mod_t MOD_T] [--unmod_t UNMOD_T] [--include_non_cpg_ref]
84 | [--input [INPUT ...]] [--list LIST]
85 |
86 | options:
87 | -h, --help show this help message and exit
88 | --prefix PREFIX Prefix for the output files (default: output)
89 | --output OUTPUT Path to folder where intermediate and final files will be stored, default is current working directory (default: None)
90 | --qscore_cutoff QSCORE_CUTOFF
91 | Minimum cutoff for mean quality score of a read (default: 0)
92 | --length_cutoff LENGTH_CUTOFF
93 | Minimum cutoff for read length (default: 0)
94 | --mod_t MOD_T Probability threshold for a per-read prediction to be considered modified. Only predictions with probability >= mod_t will be considered modified for the
95 | calculation of per-site modification levels. (default: 0.5)
96 | --unmod_t UNMOD_T Probability threshold for a per-read prediction to be considered unmodified. Only predictions with probability < unmod_t will be considered unmodified for the
97 | calculation of per-site modification levels. (default: 0.5)
98 | --include_non_cpg_ref
99 | Include non-CpG reference loci in per-site output where reads have CpG motif. (default: False)
100 | --input [INPUT ...] List of paths of per-read methylation calls to merge. File paths should be separated by space/whitespace. Use either --input or --list argument, but not both.
101 | (default: None)
102 | --list LIST A file containing paths to per-read methylation calls to merge (one per line). Use either --input or --list argument, but not both. (default: None)
103 | ```
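104 |
105 | For example, a hypothetical two-way split of a POD5 dataset (the paths and model name below are placeholders, and `output.per_read` is the per-read file written by `detect` under its default prefix):
106 | ```
107 | deepmod2 detect --model bilstm_r10.4.1_5khz_v5.0 --bam reads.bam --input pod5_split1/ --file_type pod5 --output run1 --skip_per_site
108 | deepmod2 detect --model bilstm_r10.4.1_5khz_v5.0 --bam reads.bam --input pod5_split2/ --file_type pod5 --output run2 --skip_per_site
109 | deepmod2 merge --input run1/output.per_read run2/output.per_read --output merged
110 | ```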
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: deepmod2
2 | channels:
3 | - conda-forge
4 | - bioconda
5 | - defaults
6 | dependencies:
7 | - numpy
8 | - pysam
9 | - python>=3.8
10 | - pytorch
11 | - h5py
12 | - tqdm
13 | - numba
14 | - ont-fast5-api
15 | - pip
16 | - pip:
17 | - pod5
18 |
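19 | # Assumed setup workflow (see docs/Install.md for full instructions):
20 | #   conda env create -f environment.yml
21 | #   conda activate deepmod2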
--------------------------------------------------------------------------------
/plot_utils/README.md:
--------------------------------------------------------------------------------
1 | # Signal Plotting
2 |
3 | The Jupyter notebook [Signal_Plot_Examples.ipynb](Signal_Plot_Examples.ipynb) uses the following data: [plot_files.tar.gz](https://github.com/WGLab/DeepMod2/files/14985308/plot_files.tar.gz)
4 |
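5 | A minimal sketch of calling the helper functions in [plot.py](plot.py) directly, run from this directory so that `plot.py` is importable; the BAM/POD5 paths, site coordinates, and labels are placeholders, and the notebook above shows a complete walkthrough:
6 |
7 | ```
8 | import plot
9 |
10 | # Map each read ID to the POD5 file that contains it
11 | read_files = plot.get_file_names('pod5_dir/')
12 |
13 | # Gather normalized, base-segmented signal around a site of interest for two samples
14 | mod_data = plot.get_signals('mod.bam', 'chr1', 100000, '+', read_files, 'pod5_dir/', 'dna')
15 | unmod_data = plot.get_signals('unmod.bam', 'chr1', 100000, '+', read_files, 'pod5_dir/', 'dna')
16 |
17 | # Overlay both samples with per-position medians, then compare per-base distributions
18 | plot.plot_two_samples(mod_data, unmod_data, label1='Mod', label2='Unmod')
19 | stats_df = plot.compare_violin_plot(mod_data, unmod_data, label1='Mod', label2='Unmod')
20 | ```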
--------------------------------------------------------------------------------
/plot_utils/plot.py:
--------------------------------------------------------------------------------
1 | import pod5 as p5
2 | import numpy as np
3 | from pathlib import Path
4 | import os, pysam
5 | import pandas as pd
6 | from scipy import stats
7 | import matplotlib.pyplot as plt
8 | import matplotlib as mpl
9 |
10 | import matplotlib.patches as mpatches
11 | from matplotlib.lines import Line2D
12 | import plotly.graph_objects as go
13 |
14 | comp_base_map={'A':'T','T':'A','C':'G','G':'C'}
15 | base_map={'A':0, 'C':1, 'G':2, 'T':3, 'U':3}
16 | rev_base_map={0:'A', 1:'C', 2:'G', 3:'T'}
17 | strand_map={'+':0, '-':1}
18 |
19 | def get_x_and_y_axes(split_signal):
20 | x_axes = []
21 | y_axes = []
22 | x_axis = np.array([])
23 | y_axis = []
24 |
25 | for i, base_signal in enumerate(split_signal):
26 | x_axis = np.hstack([x_axis, np.linspace(i+0.05, i+1-0.05, base_signal.shape[0])])
27 | y_axis.append(base_signal)
28 |
29 | return x_axis, np.hstack(y_axis)
30 |
31 |
32 | def plot_read(read_data, marker_transparency=0.8, line_plot=True, marker_size=20, lim=0, display_average=True, save_path=None):
33 | split_signals, seq=read_data['signal'], read_data['seq']
34 | if lim==0:
35 | u_lim=np.max(np.hstack(split_signals))
36 | l_lim=np.min(np.hstack(split_signals))
37 | else:
38 | u_lim, l_lim=lim, -1*lim
39 |
40 | K=len(seq)
41 |
42 | p_x_axis, p_y_axis=get_x_and_y_axes(split_signals)
43 |
44 | plt.figure(figsize=(12, 3),dpi=100)
45 |
46 | if line_plot:
47 | plt.plot(p_x_axis, p_y_axis, linewidth=0.5, alpha=0.5, color='green')
48 | plt.scatter(p_x_axis, p_y_axis,alpha=marker_transparency, s=marker_size, color='green', edgecolor='black')
49 |
50 | plt.axhline(y=0, linestyle='dotted')
51 | for i in range(K):
52 | plt.axvline(x=i, linestyle='dotted', ymin=-4, ymax=4)
53 |
54 | handles=[]
55 | if display_average:
56 | for i in range(0,K):
57 | plt.hlines(y=np.mean(split_signals[i]),xmin=i,xmax=i+1,color='red',linestyle='-')
58 | plt.hlines(y=np.median(split_signals[i]),xmin=i,xmax=i+1,color='blue',linestyle='-')
59 |
60 | handles.append(mpatches.Patch(color='red', label='Mean Signal'))
61 | handles.append(mpatches.Patch(color='blue', label='Median Signal'))
62 |
63 | plt.axis([0, K, 1.1*l_lim, 1.1*u_lim])
64 | plt.xticks(np.arange(0, K) + 1/2, list(seq))
65 |
66 | handles.append(Line2D([0], [0], marker='o', color='white', label='Signal', alpha=0.5, markeredgecolor='black', markerfacecolor='green', markersize=5))
67 | leg=plt.legend(handles=handles,loc='upper right',fontsize=10)
68 | for patch in leg.get_patches():
69 | patch.set_height(5)
70 |
71 | plt.ylabel('Signal')
72 | plt.xlabel('Bases')
73 |
74 | if save_path!=None:
75 | plt.savefig(save_path, bbox_inches='tight')
76 |
77 | plt.show()
78 |
79 | def plot_single_sample(data, line_plot=False, color='green',lim=0, marker_size=10, marker_transparency=0.2, display_average=True, save_path=None):
80 | plt.figure(figsize=(12, 3),dpi=100)
81 |
82 | if lim==0:
83 | u_lim=max(np.max(np.hstack(x['signal'])) for x in data.values())
84 | l_lim=min(np.min(np.hstack(x['signal'])) for x in data.values())
85 | else:
86 | u_lim, l_lim=lim, -1*lim
87 |
88 |
89 | cons_seq=get_consensus(data)
90 | K=len(cons_seq)
91 |
92 | bin_list=[[] for i in range(K)]
93 |
94 | for read_data in data.values():
95 | for i in range(K):
96 | for signal in read_data['signal'][i]:
97 | bin_list[i].append(signal)
98 |
99 | p_x_axis, p_y_axis=get_x_and_y_axes(read_data['signal'])
100 |
101 |
102 | if line_plot:
103 | plt.plot(p_x_axis, p_y_axis, linewidth=0.5, alpha=0.5, color=color)
104 |
105 | plt.scatter(p_x_axis, p_y_axis, alpha=marker_transparency, s=marker_size, color=color, edgecolor=None)
106 |
107 | plt.axhline(y=0, linestyle='dotted')
108 | for i in range(K):
109 | plt.axvline(x=i, linestyle='dotted', ymin=-4, ymax=4)
110 |
111 | handles=[]
112 | if display_average:
113 | for i in range(0,K):
114 | plt.hlines(y=np.mean(bin_list[i]),xmin=i,xmax=i+1,color='black',linestyle='-')
115 | plt.hlines(y=np.median(bin_list[i]),xmin=i,xmax=i+1,color='blue',linestyle='-')
116 | handles.append(mpatches.Patch(color='black', label='Mean Signal'))
117 | handles.append(mpatches.Patch(color='blue', label='Median Signal'))
118 |
119 | plt.axis([0, K, 1.1*l_lim, 1.1*u_lim])
120 | plt.xticks(np.arange(0, K) + 1/2, cons_seq)
121 |
122 | handles.append(Line2D([0], [0], marker='o', color='white', label='Signal', alpha=0.5, markeredgecolor='black', markerfacecolor=color, markersize=5))
123 | leg=plt.legend(handles=handles,loc='upper right',fontsize=10)
124 | for patch in leg.get_patches():
125 | patch.set_height(5)
126 | plt.ylabel('Signal')
127 | plt.xlabel('Bases')
128 |
129 | if save_path!=None:
130 | plt.savefig(save_path,bbox_inches='tight')
131 |
132 | plt.show()
133 |
134 | def plot_two_samples(sample1_data, sample2_data, label1='Mod', label2='Unmod', line_plot=False, lim=0, marker_size=10, marker_transparency=0.2, display_average=True, average_type='median', save_path=None, color1='green', color2='red'):
135 |
136 | plt.figure(figsize=(12, 3),dpi=100)
137 |
138 | tmp_dict=dict(sample1_data)
139 | tmp_dict.update(sample2_data)
140 | cons_seq=get_consensus(tmp_dict)
141 |
142 | if lim==0:
143 | u_lim=max(np.max(np.hstack(x['signal'])) for x in tmp_dict.values())
144 | l_lim=min(np.min(np.hstack(x['signal'])) for x in tmp_dict.values())
145 | else:
146 | u_lim, l_lim=lim, -1*lim
147 |
148 | K=len(cons_seq)
149 |
150 | sample_1_bin_list=[[] for i in range(K)]
151 | sample_2_bin_list=[[] for i in range(K)]
152 |
153 |
154 | for read_data in sample1_data.values():
155 | for i in range(K):
156 | for signal in read_data['signal'][i]:
157 | sample_1_bin_list[i].append(signal)
158 |
159 | p_x_axis, p_y_axis=get_x_and_y_axes(read_data['signal'])
160 |
161 |
162 | if line_plot:
163 | plt.plot(p_x_axis, p_y_axis, linewidth=0.5, alpha=0.5, color=color1)
164 |
165 | plt.scatter(p_x_axis, p_y_axis, alpha=marker_transparency, s=marker_size, color=color1, edgecolor=None)
166 |
167 | for read_data in sample2_data.values():
168 | for i in range(K):
169 | for signal in read_data['signal'][i]:
170 | sample_2_bin_list[i].append(signal)
171 |
172 | p_x_axis, p_y_axis=get_x_and_y_axes(read_data['signal'])
173 |
174 |
175 | if line_plot:
176 | plt.plot(p_x_axis, p_y_axis, linewidth=0.5, alpha=0.5, color=color2)
177 |
178 | plt.scatter(p_x_axis, p_y_axis, alpha=marker_transparency, s=marker_size, color=color2, edgecolor=None)
179 |
180 | plt.axhline(y=0, linestyle='dotted')
181 | for i in range(K):
182 | plt.axvline(x=i, linestyle='dotted', ymin=-4, ymax=4)
183 |
184 | handles=[]
185 | if display_average:
186 | for i in range(0,K):
187 | if average_type=='mean':
188 | plt.hlines(y=np.mean(sample_1_bin_list[i]),xmin=i,xmax=i+1,color=color1,linestyle='-')
189 | plt.hlines(y=np.mean(sample_2_bin_list[i]),xmin=i,xmax=i+1,color=color2,linestyle='-')
190 |
191 | elif average_type=='median':
192 | plt.hlines(y=np.median(sample_1_bin_list[i]),xmin=i,xmax=i+1,color=color1,linestyle='-')
193 | plt.hlines(y=np.median(sample_2_bin_list[i]),xmin=i,xmax=i+1,color=color2,linestyle='-')
194 |
195 | handles.append(mpatches.Patch(color=color1, label='{} {} Signal'.format(label1, average_type.capitalize())))
196 | handles.append(mpatches.Patch(color=color2, label='{} {} Signal'.format(label2, average_type.capitalize())))
197 |
198 | plt.axis([0, K, 1.1*l_lim, 1.1*u_lim])
199 | plt.xticks(np.arange(0, K) + 1/2, cons_seq)
200 |
201 | handles.append(Line2D([0], [0], marker='o', color='white', label='{} Signal'.format(label1), alpha=0.5, markeredgecolor='black', markerfacecolor='green', markersize=5))
202 | handles.append(Line2D([0], [0], marker='o', color='white', label='{} Signal'.format(label2), alpha=0.5, markeredgecolor='black', markerfacecolor='red', markersize=5))
203 | leg=plt.legend(handles=handles,loc='upper right',fontsize=10)
204 |
205 | for patch in leg.get_patches():
206 | patch.set_height(5)
207 | plt.ylabel('Signal')
208 | plt.xlabel('Bases')
209 |
210 | if save_path!=None:
211 | plt.savefig(save_path,bbox_inches='tight')
212 |
213 | plt.show()
214 |
215 | def violin_plot(sample_data, avg_type='median', static_display=False, meanline_visible=False, figure_width=1000, figure_height=500, save_path=None):
216 |
217 | cons_seq=get_consensus(sample_data)
218 |
219 | K=len(cons_seq)
220 | d=[]
221 |
222 |
223 | j=0
224 | for i in range(K):
225 | for read_data in sample_data.values():
226 | if avg_type=='mean':
227 | d.append([j,np.mean(read_data['signal'][i])])
228 | elif avg_type=='median':
229 | d.append([j,np.median(read_data['signal'][i])])
230 | j+=1
231 |
232 | df=pd.DataFrame(d)
233 | df.rename(columns={0:'Position',1:'Signal'}, inplace=True)
234 |
235 | fig = go.Figure()
236 |
237 | fig.add_trace(go.Violin(x=df['Position'],
238 | y=df['Signal'],
239 | points=False,
240 | line=dict(color="blue", width=0.5),meanline=dict(color="blue", width=2),
241 | meanline_visible=meanline_visible)
242 | )
243 |
244 | tickvals=np.arange(0,len(cons_seq))
245 | ticktext=['%d<br>%s' %(a-len(cons_seq)//2,b) for a,b in zip(tickvals, cons_seq)]
246 | fig.update_xaxes(tickmode='array', tickvals=tickvals, ticktext=ticktext)
247 | fig.update_layout(violingap=0, violinmode='overlay')
248 |
249 |
250 |
251 | if static_display:
252 | fig.update_layout(autosize=False, width=figure_width, height=figure_height)
253 | if save_path!=None:
254 | fig.write_image(save_path)
255 | fig.show(renderer="svg")
256 | else:
257 | if save_path!=None:
258 | fig.write_html(save_path)
259 | fig.show()
260 |
261 | return
262 |
263 |
264 | def compare_violin_plot(sample1_data, sample2_data, label1='Mod', label2='Unmod', avg_type='median', static_display=False, meanline_visible=False, figure_width=1000, figure_height=500, save_path=None, test_type='mw', test_method="auto", display_pval=True):
265 | tmp_dict=dict(sample1_data)
266 | tmp_dict.update(sample2_data)
267 | cons_seq=get_consensus(tmp_dict)
268 |
269 | K=len(cons_seq)
270 | d=[]
271 |
272 |
273 | j=0
274 | for i in range(K):
275 | for read_data in sample1_data.values():
276 | if avg_type=='mean':
277 | d.append([j,np.mean(read_data['signal'][i])])
278 | elif avg_type=='median':
279 | d.append([j,np.median(read_data['signal'][i])])
280 | j+=1
281 |
282 | pos_df=pd.DataFrame(d)
283 | pos_df.rename(columns={0:'Position',1:'Signal'}, inplace=True)
284 | pos_df['Sample']=label1
285 |
286 | d=[]
287 |
288 |
289 | j=0
290 | for i in range(K):
291 | for read_data in sample2_data.values():
292 | if avg_type=='mean':
293 | d.append([j,np.mean(read_data['signal'][i])])
294 | elif avg_type=='median':
295 | d.append([j,np.median(read_data['signal'][i])])
296 | j+=1
297 |
298 | neg_df=pd.DataFrame(d)
299 | neg_df.rename(columns={0:'Position',1:'Signal'}, inplace=True)
300 | neg_df['Sample']=label2
301 |
302 | df=pd.concat([pos_df, neg_df])
303 | fig = go.Figure()
304 |
305 | fig.add_trace(go.Violin(x=df['Position'][ df['Sample'] == label1],
306 | y=df['Signal'][ df['Sample'] == label1 ],
307 | legendgroup=label1, scalegroup=label1, name=label1, points=False,
308 | side='negative',
309 | line=dict(color="blue", width=0.5),meanline=dict(color="blue", width=2),
310 | meanline_visible=meanline_visible)
311 | )
312 | fig.add_trace(go.Violin(x=df['Position'][ df['Sample'] == label2],
313 | y=df['Signal'][ df['Sample'] == label2 ],
314 | legendgroup=label2, scalegroup=label2, name=label2, points=False,
315 | side='positive',
316 | line=dict(color="orange", width=0.5), meanline=dict(color="orange", width=2),
317 | meanline_visible=meanline_visible)
318 | )
319 |
320 |
321 | group_df=df.groupby(['Position', 'Sample'])['Signal'].apply(list)
322 | dist_stats={}
323 | for i in range(len(group_df)//2):
324 | if test_type=='ks':
325 | s=stats.ks_2samp(group_df.loc[i,:].loc[label1], group_df.loc[i,:].loc[label2], method=test_method)
326 | elif test_type=="mw":
327 | s=stats.mannwhitneyu(group_df.loc[i,:].loc[label1], group_df.loc[i,:].loc[label2], method=test_method)
328 | dist_stats[i]={'Position':i,'Base':cons_seq[i],'Statistic':s.statistic, 'Pvalue':s.pvalue}
329 |
330 | dist_stats_df=pd.DataFrame(dist_stats).T
331 | dist_stats_df=dist_stats_df.astype({'Position': 'int', 'Statistic': 'float32', 'Pvalue': 'float32',})
332 |
333 | tickvals=np.arange(0,len(cons_seq))
334 | if display_pval:
335 | ticktext=['{}<br>{}<br>{:0.1e}'.format(a-len(cons_seq)//2,b,c) for a,b,c in zip(tickvals, cons_seq, dist_stats_df.Pvalue)]
336 | else:
337 | ticktext=['%d<br>%s' %(a-len(cons_seq)//2,b) for a,b in zip(tickvals, cons_seq)]
338 | fig.update_xaxes(tickmode='array', tickvals=tickvals, ticktext=ticktext,tickfont = dict(size = 8))
339 | fig.update_layout(violingap=0, violinmode='overlay')
340 |
341 | if static_display:
342 | fig.update_layout(autosize=False, width=figure_width, height=figure_height)
343 | if save_path!=None:
344 | fig.write_image(save_path)
345 | fig.show(renderer="svg")
346 | else:
347 | if save_path!=None:
348 | fig.write_html(save_path)
349 | fig.show()
350 |
351 | return dist_stats_df
352 |
353 |
354 | def revcomp(s):
355 | return ''.join(comp_base_map[x] for x in s[::-1])
356 |
357 | def get_consensus(data):
358 | K=len(next(iter(data.values()))['seq'])
359 | cons_seq_array=np.zeros((K,4))
360 | for x in data.values():
361 | for i in range(K):
362 | cons_seq_array[i][base_map[x['seq'][i]]]+=1
363 |
364 | cons_seq=''.join(rev_base_map[t] for t in np.argmax(cons_seq_array,axis=1))
365 | return cons_seq
366 |
367 | def get_file_names(base_path):
368 | read_filename_dict={}
369 |
370 | if os.path.isdir(base_path):
371 | files=Path(base_path).rglob('*.pod5')
372 | else:
373 | files=[base_path]
374 | for read_path in files:
375 | read_path=str(read_path)
376 | with p5.Reader(read_path) as reader:
377 | for rname in reader.read_ids:
378 | read_filename_dict[rname]=read_path
379 |
380 | return read_filename_dict
381 |
382 | def get_read_positions(bam_path, chrom, pos, strand, seq_type):
383 | flag=0x4|0x100|0x200|0x400|0x800
384 |
385 | read_info={}
386 |
387 | bam=pysam.AlignmentFile(bam_path,'rb')
388 |
389 | for pcol in bam.pileup(contig=chrom, start=pos-1, end=pos, flag_filter=flag, truncate=True, min_base_quality = 0):
390 | if strand=='+':
391 | for read in pcol.pileups:
392 | if read.alignment.is_reverse==False:
393 | if read.is_del:
394 | print('DEL', read.alignment.qname) #log the skipped deletion read before moving on
395 | continue
396 | else:
397 | if seq_type=='dna':
398 | read_info[read.alignment.qname]=(read.query_position, False, read.alignment.to_dict())
399 | elif seq_type=='rna':
400 | read_info[read.alignment.qname]=(read.alignment.query_length-read.query_position-1, False, read.alignment.to_dict())
401 | elif strand=='-':
402 | for read in pcol.pileups:
403 | if read.alignment.is_reverse:
404 | if read.is_del:
405 | print('DEL', read.alignment.qname) #log the skipped deletion read before moving on
406 | continue
407 | else:
408 | if seq_type=='dna':
409 | read_info[read.alignment.qname]=(read.alignment.query_length-read.query_position-1, True, read.alignment.to_dict())
410 |
411 | elif seq_type=='rna':
412 | read_info[read.alignment.qname]=(read.query_position, True, read.alignment.to_dict())
413 |
414 | return read_info
415 |
416 | def get_read_signal_raw(signal, move,norm_type):
417 | stride, start, move_table=move
418 | median=np.median(signal)
419 | mad=np.median(np.abs(signal-median))
420 |
421 | if norm_type=='MAD':
422 | norm_signal=(signal-median)/mad
423 | elif norm_type=='STD':
424 | norm_signal=(signal-np.mean(signal))/np.std(signal)
425 |
426 | move_len=len(move_table)
427 | move_index=np.where(move_table)[0]
428 | rlen=len(move_index)
429 |
430 | base_level_data=[]
431 | for i in range(len(move_index)-1):
432 | prev=move_index[i]*stride+start
433 | sig_end=move_index[i+1]*stride+start
434 | base_level_data.append([prev,sig_end-prev])
435 |
436 | return norm_signal, base_level_data
437 |
438 | def get_signals(bam_path, chrom, pos, strand, read_filename_dict, base_path, seq_type, max_cov=1000, window_before=10, window_after=10, norm_type='STD'):
439 | read_info=get_read_positions(bam_path, chrom, pos, strand,seq_type)
440 | data={}
441 |
442 | if seq_type=='rna':
443 | window_before, window_after=window_after, window_before
444 |
445 | cov=0
446 | for read_name in read_info.keys():
447 | if cov > max_cov:
448 | break
449 | try:
450 | read_path=read_filename_dict[read_name]
451 | except KeyError:
452 | continue
453 |
454 | with p5.Reader(read_path) as reader:
455 | raw_read=next(reader.reads([read_name]))
456 | try:
457 | read_pos, reverse, read_dict=read_info[read_name]
458 |
459 | except KeyError:
460 | continue
461 |
462 | tags={x.split(':')[0]:x for x in read_dict['tags']}
463 | start=int(tags['ts'].split(':')[-1])
464 | mv=tags['mv'].split(',')
465 |
466 | stride=int(mv[1])
467 | move_table=np.fromiter(mv[2:], dtype=np.int8)
468 | move=(stride, start, move_table)
469 |
470 | signal=raw_read.signal
471 |
472 | fq=read_dict['seq']
473 | fq=revcomp(fq) if reverse else fq
474 |
475 | if seq_type=='rna':
476 | fq=fq[::-1]
477 |
478 | norm_signal, base_level_data = get_read_signal_raw(signal, move, norm_type)
479 | seq_len=len(fq)
480 |
481 | if read_pos>window_before+5 and read_pos<seq_len-window_after-5:
--------------------------------------------------------------------------------
/src/detect.py:
--------------------------------------------------------------------------------
38 | signal[signal>5]=5
39 | signal[signal<-5]=-5
40 |
41 | move_len=len(move_table)
42 |
43 | data=np.zeros((move_len,stride+2))
44 | indexes=np.full(move_len,fill_value=0,dtype=np.int32)
45 | z=1
46 | idx=-1
47 |
48 | segments=np.full(np.sum(move_table),fill_value=0,dtype=np.int32)
49 |
50 | for i in range(move_len):
51 | if move_table[i]:
52 | z=z^1 #toggle the one-hot indicator column so adjacent base segments are distinguishable
53 | idx+=1
54 | segments[idx]=i
55 |
56 | data[i,z]=1
57 | indexes[i]=idx
58 | for k in range(stride):
59 | data[i,2+k]=signal[start+i*stride+k]
60 |
61 | return data, indexes, segments
62 |
63 | @jit(nopython=True)
64 | def get_events(signal, move, norm_type):
65 | stride, start, move_table=move
66 |
67 | if norm_type=='mad':
68 | median=np.median(signal)
69 | mad=np.median(np.abs(signal-median))
70 | signal=(signal-median)/mad
71 |
72 | else:
73 | mean=np.mean(signal)
74 | std=np.std(signal)
75 | signal=(signal-mean)/std
76 |
77 | move_len=len(move_table)
78 | move_index=np.where(move_table)[0]
79 | rlen=len(move_index)
80 |
81 | data=np.zeros((rlen,9))
82 |
83 | for i in range(len(move_index)-1):
84 | prev=move_index[i]*stride+start
85 | sig_end=move_index[i+1]*stride+start
86 |
87 | sig_len=sig_end-prev
88 | data[i, 8]=np.log10(sig_len)
89 | data[i, 4]=np.median(signal[prev:sig_end])
90 | data[i, 5]=np.median(np.abs(signal[prev:sig_end]-data[i, 4]))
91 | data[i, 6]=np.mean(signal[prev:sig_end])
92 | data[i, 7]=np.std(signal[prev:sig_end])
93 |
94 | for j in range(4):
95 | tmp_cnt=0
96 | for t in range(j*sig_len//4,min(sig_len, (j+1)*sig_len//4)):
97 | data[i, j]+=signal[t+prev]
98 | tmp_cnt+=1
99 | data[i, j]=data[i, j]/tmp_cnt
100 |
101 | return data
102 |
103 | def get_candidates(read_seq, align_data, aligned_pairs, ref_pos_dict, exp_motif_seq, motif_base, motif_ind, position_based):
104 | if align_data[0]:
105 | is_mapped, is_forward, ref_name, reference_start, reference_end, read_length=align_data
106 |
107 | base_id={m.start(0):i for i,m in enumerate(re.finditer(r'(?={})'.format(motif_base), read_seq))}
108 |
109 | motif_anchor=np.array([m.start(0) for m in re.finditer(r'(?={})'.format(exp_motif_seq), read_seq)])
110 | motif_id=np.array(sorted(list(set.union(*[set(motif_anchor+i) for i in motif_ind]))))
111 |
112 | ref_motif_pos=ref_pos_dict[ref_name][0] if is_forward else ref_pos_dict[ref_name][1]
113 |
114 | common_pos=ref_motif_pos[(ref_motif_pos>=reference_start)&(ref_motif_pos<reference_end)]
126 | if len(motif_id)>0 and not position_based:
127 | aligned_pairs_read_wise=aligned_pairs_read_wise[motif_id]
128 |
129 | #if need to disable clipped bases
130 | #aligned_pairs_read_wise=aligned_pairs_read_wise[(reference_start<=aligned_pairs_read_wise[:,1]) & (aligned_pairs_read_wise[:,1]<reference_end)]
201 | if counter//10000>counter_check:
202 | counter_check=counter//10000
203 | print('%s: Number of reads processed: %d' %(str(datetime.datetime.now()), counter), flush=True)
204 | if res[0]:
205 | _, total_read_info, total_candidate_list, total_MM_list, read_qual_list, pred_list = res
206 | for read_data, candidate_list, MM, ML, pred_list in zip(*res[1:]):
207 | counter+=1
208 | read_dict, read_info = read_data
209 | read=pysam.AlignedSegment.from_dict(read_dict,header)
210 | if MM:
211 | read.set_tag('MM',MM,value_type='Z')
212 | read.set_tag('ML',ML)
213 |
214 | outf.write(read)
215 |
216 | read_name=read_dict['name']
217 | is_forward, chrom, read_length, mean_qscore=read_info
218 | chrom=chrom if chrom else 'NA'
219 |
220 | strand='+' if is_forward else '-'
221 |
222 | phase=0
223 | phase=read.get_tag('HP') if read.has_tag('HP') else 0
224 |
225 | idx=4*phase+2*is_forward
226 |
227 | if float(mean_qscore)<qscore_cutoff or int(read_length)<length_cutoff:
248 | if score<mod_threshold and score>=unmod_threshold:
249 | pass
250 | elif not skip_per_site:
251 | mod=score>=mod_threshold
252 |
253 | if (chrom, zero_based_fwd_pos) not in per_site_pred:
254 | per_site_pred[(chrom, zero_based_fwd_pos)]=[0]*12+[is_ref_cpg]
255 |
256 | per_site_pred[(chrom, zero_based_fwd_pos)][idx+mod]+=1
257 |
258 | per_read_file.write('%s\t%s\t%s\t%s\t%d\t%s\t%.4f\t%.2f\t%d\t%d\t%s\n' %(read_name, chrom, ref_pos_str_before, ref_pos_str_after, read_pos, strand, score, mean_qscore, read_length, phase, is_ref_cpg))
259 |
260 |
261 | else:
262 | _, total_read_info=res
263 | for read_dict in total_read_info:
264 | counter+=1
265 | read=pysam.AlignedSegment.from_dict(read_dict,header)
266 | outf.write(read)
267 |
268 | except queue.Empty:
269 | pass
270 |
271 | print('%s: Number of reads processed: %d' %(str(datetime.datetime.now()), counter), flush=True)
272 | print('%s: Finished Per-Read Methylation Output. Starting Per-Site output.' %str(datetime.datetime.now()), flush=True)
273 | print('%s: Modification Tagged BAM file: %s' %(str(datetime.datetime.now()),bam_output), flush=True)
274 | print('%s: Per Read Prediction file: %s' %(str(datetime.datetime.now()), per_read_file_path), flush=True)
275 | print('%s: Writing Per Site Methylation Detection.' %str(datetime.datetime.now()), flush=True)
276 |
277 | if skip_per_site:
278 | return
279 | per_site_fields=['#chromosome', 'position_before', 'position','strand', 'ref_cpg',
280 | 'coverage','mod_coverage', 'unmod_coverage','mod_fraction',
281 | 'coverage_phase1','mod_coverage_phase1', 'unmod_coverage_phase1','mod_fraction_phase1',
282 | 'coverage_phase2','mod_coverage_phase2', 'unmod_coverage_phase2','mod_fraction_phase2']
283 | per_site_header='\t'.join(per_site_fields)+'\n'
284 | per_site_fields.remove('strand')
285 | agg_per_site_header='\t'.join(per_site_fields)+'\n'
286 |
287 | per_site_file_path=os.path.join(params['output'],'%s.per_site' %params['prefix'])
288 | agg_per_site_file_path=os.path.join(params['output'],'%s.per_site.aggregated' %params['prefix'])
289 |
290 | with open(per_site_file_path, 'w') as per_site_file, open(agg_per_site_file_path,'w') as agg_per_site_file:
291 | per_site_file.write(per_site_header)
292 | agg_per_site_file.write(agg_per_site_header)
293 |
294 | for x in sorted(per_site_pred.keys()):
295 | chrom, pos=x
296 | cpg=per_site_pred[x]
297 | is_ref_cpg=cpg[12]
298 |
299 | if cpg_ref_only and is_ref_cpg==False:
300 | continue
301 | #fwd_stats=[self.chrom, self.position, self.position+1, '+', self.is_ref_cpg]+self.get_all_phases().forward.stats() + self.phase_1.forward.stats() + self.phase_2.forward.stats()
302 |
303 | agg_stats, fwd_stats, rev_stats=get_stats_string_cpg(chrom, pos, is_ref_cpg, cpg)
304 | if agg_stats[0]>0:
305 | agg_per_site_file.write(agg_stats[1])
306 |
307 | if fwd_stats[0]>0:
308 | per_site_file.write(fwd_stats[1])
309 |
310 | if rev_stats[0]>0:
311 | per_site_file.write(rev_stats[1])
312 |
313 | print('%s: Finished Writing Per Site Methylation Output.' %str(datetime.datetime.now()), flush=True)
314 | print('%s: Per Site Prediction file: %s' %(str(datetime.datetime.now()), per_site_file_path), flush=True)
315 | print('%s: Aggregated Per Site Prediction file: %s' %(str(datetime.datetime.now()), agg_per_site_file_path), flush=True)
316 |
317 | return
318 |
319 | def get_output(params, output_Q, methylation_event, header_dict, ref_pos_dict):
320 | header=pysam.AlignmentHeader.from_dict(header_dict)
321 |
322 | bam_threads=params['bam_threads']
323 |
324 | output=params['output']
325 | bam_output=os.path.join(output,'%s.bam' %params['prefix'])
326 | per_read_file_path=os.path.join(output,'%s.per_read' %params['prefix'])
327 |
328 |
329 | per_site_file_path=os.path.join(output,'%s.per_site' %params['prefix'])
330 | agg_per_site_file_path=os.path.join(output,'%s.per_site.aggregated' %params['prefix'])
331 | qscore_cutoff=params['qscore_cutoff']
332 | length_cutoff=params['length_cutoff']
333 |
334 | mod_threshold=params['mod_t']
335 | unmod_threshold=params['unmod_t']
336 |
337 | skip_per_site=params['skip_per_site']
338 |
339 | per_site_pred={}
340 |
341 | counter=0
342 |
343 | counter_check=0
344 | with open(per_read_file_path,'w') as per_read_file:
345 | per_read_file.write('read_name\tchromosome\tref_position_before\tref_position\tread_position\tstrand\tmethylation_score\tmean_read_qscore\tread_length\tread_phase\n')
346 |
347 | with pysam.AlignmentFile(bam_output, "wb", threads=bam_threads, header=header) as outf:
348 | while True:
349 | if methylation_event.is_set() and output_Q.empty():
350 | break
351 | else:
352 | try:
353 | res = output_Q.get(block=False, timeout=10)
354 | #continue
355 | if counter//10000>counter_check:
356 | counter_check=counter//10000
357 | print('%s: Number of reads processed: %d' %(str(datetime.datetime.now()), counter), flush=True)
358 | if res[0]:
359 | _, total_read_info, total_candidate_list, total_MM_list, read_qual_list, pred_list = res
360 | for read_data, candidate_list, MM, ML, pred_list in zip(*res[1:]):
361 | counter+=1
362 | read_dict, read_info = read_data
363 | read=pysam.AlignedSegment.from_dict(read_dict,header)
364 | if MM:
365 | read.set_tag('MM',MM,value_type='Z')
366 | read.set_tag('ML',ML)
367 |
368 | outf.write(read)
369 |
370 | read_name=read_dict['name']
371 | is_forward, chrom, read_length, mean_qscore=read_info
372 | chrom=chrom if chrom else 'NA'
373 |
374 | strand='+' if is_forward else '-'
375 |
376 | phase=0
377 | phase=read.get_tag('HP') if read.has_tag('HP') else 0
378 |
379 |
380 | if float(mean_qscore)<qscore_cutoff or int(read_length)<length_cutoff:
395 | if score<mod_threshold and score>=unmod_threshold:
396 | pass
397 | elif not skip_per_site:
398 | mod=score>=mod_threshold
399 |
400 | if (chrom, ref_pos,strand) not in per_site_pred:
401 | per_site_pred[(chrom, ref_pos,strand)]=[0]*6
402 |
403 | per_site_pred[(chrom, ref_pos,strand)][2*phase+mod]+=1
404 |
405 | per_read_file.write('%s\t%s\t%s\t%s\t%d\t%s\t%.4f\t%.2f\t%d\t%d\n' %(read_name, chrom, ref_pos_str_before, ref_pos_str_after, read_pos, strand, score, mean_qscore, read_length, phase))
406 |
407 |
408 | else:
409 | _, total_read_info=res
410 | for read_dict in total_read_info:
411 | counter+=1
412 | read=pysam.AlignedSegment.from_dict(read_dict,header)
413 | outf.write(read)
414 |
415 | except queue.Empty:
416 | pass
417 |
418 | print('%s: Number of reads processed: %d' %(str(datetime.datetime.now()), counter), flush=True)
419 | print('%s: Finished Per-Read Methylation Output. Starting Per-Site output.' %str(datetime.datetime.now()), flush=True)
420 | print('%s: Modification Tagged BAM file: %s' %(str(datetime.datetime.now()),bam_output), flush=True)
421 | print('%s: Per Read Prediction file: %s' %(str(datetime.datetime.now()), per_read_file_path), flush=True)
422 |
423 | if skip_per_site:
424 | return
425 | per_site_fields=['#chromosome', 'position_before', 'position','strand',
426 | 'coverage','mod_coverage', 'unmod_coverage','mod_fraction',
427 | 'coverage_phase1','mod_coverage_phase1', 'unmod_coverage_phase1','mod_fraction_phase1',
428 | 'coverage_phase2','mod_coverage_phase2', 'unmod_coverage_phase2','mod_fraction_phase2']
429 | per_site_header='\t'.join(per_site_fields)+'\n'
430 |
431 |
432 | per_site_file_path=os.path.join(params['output'],'%s.per_site' %params['prefix'])
433 |
434 | with open(per_site_file_path, 'w') as per_site_file:
435 | per_site_file.write(per_site_header)
436 |
437 | for x in sorted(per_site_pred.keys()):
438 | chrom, pos, strand=x
439 | mod_call=per_site_pred[x]
440 |
441 | stats=get_stats_string(chrom, pos, strand, mod_call)
442 | if stats[0]>0:
443 | per_site_file.write(stats[1])
444 |
445 | print('%s: Finished Writing Per Site Methylation Output.' %str(datetime.datetime.now()), flush=True)
446 | print('%s: Per Site Prediction file: %s' %(str(datetime.datetime.now()), per_site_file_path), flush=True)
447 | return
448 |
449 | def process(params,ref_pos_dict, signal_Q, output_Q, input_event, ref_seq_dict):
450 | torch.set_grad_enabled(False);
451 |
452 | dev=params['dev']
453 | motif_seq=params['motif_seq']
454 | exp_motif_seq=params['exp_motif_seq']
455 | motif_base=motif_seq[params['motif_ind'][0]]
456 | motif_ind=params['motif_ind']
457 |
458 | if params['mod_symbol']:
459 | mod_symbol=params['mod_symbol']
460 | elif motif_seq=='CG':
461 | mod_symbol='m'
462 | else:
463 | mod_symbol=motif_base
464 |
465 | seq_type=params['seq_type']
466 |
467 | position_based=params['position_based']
468 |
469 | base_map={'A':0, 'C':1, 'G':2, 'T':3, 'U':3}
470 |
471 | cigar_map={'M':0, '=':0, 'X':0, 'D':1, 'I':2, 'S':2,'H':2, 'N':1, 'P':4, 'B':4}
472 | cigar_pattern = r'\d+[A-Za-z=]'
473 |
474 | model, model_config=get_model(params)
475 | window=model_config['window']
476 | full_signal=model_config['full_signal']
477 | strides_per_base=model_config['strides_per_base']
478 | norm_type=model_config['norm_type']
479 |
480 | model.eval()
481 | model.to(dev);
482 | print('CUDA test:', torch.cuda.is_available(), next(model.parameters()).is_cuda, flush=True)
483 |
484 | reads_per_round=100
485 |
486 | chunk_size=256 if dev=='cpu' else params['batch_size'] #CPU inference uses a fixed batch size; GPU inference uses --batch_size
487 |
488 | total_candidate_list=[]
489 | total_feature_list=[]
490 | total_base_seq_list=[]
491 | total_MM_list=[]
492 | total_read_info=[]
493 | total_c_idx=[]
494 | total_unprocessed_reads=[]
495 | total_ref_seq_list=[]
496 | r_count=0
497 |
498 | dummy_ref_seq=4+np.zeros(1000000)
499 |
500 | ref_available=True if params['ref'] else False
501 |
502 | while True:
503 | if (signal_Q.empty() and input_event.is_set()):
504 | break
505 |
506 | try:
507 |
508 | chunk=signal_Q.get(block=False, timeout=10)
509 | #print('%s: Output_qsize=%d Signal_qsize=%d' %(str(datetime.datetime.now()), output_Q.qsize(), signal_Q.qsize()),flush=True)
510 | if output_Q.qsize()>200:
511 | time.sleep(30)
512 | if output_Q.qsize()>500:
513 | time.sleep(60)
514 | print('Pausing output due to queue size limit. Output_qsize=%d Signal_qsize=%d' %(output_Q.qsize(), signal_Q.qsize()), flush=True)
515 | for data in chunk:
516 | signal, move, read_dict, align_data=data
517 |
518 | is_mapped, is_forward, ref_name, reference_start, reference_end, read_length=align_data
519 |
520 | fq=read_dict['seq']
521 | qual=read_dict['qual']
522 | sequence_length=len(fq)
523 | reverse= not is_forward
524 | fq=revcomp(fq) if reverse else fq
525 | qual=qual[::-1] if reverse else qual
526 |
527 | if is_mapped and ref_available:
528 | cigar_tuples = np.array([(int(x[:-1]), cigar_map[x[-1]]) for x in re.findall(cigar_pattern, read_dict['cigar'])])
529 | ref_start=int(read_dict['ref_pos'])-1
530 | aligned_pairs=get_aligned_pairs(cigar_tuples, ref_start)
531 | else:
532 | aligned_pairs=None
533 |
534 |
535 | pos_list_c, pos_list_candidates, read_to_ref_pairs=get_candidates(fq, align_data, aligned_pairs, ref_pos_dict, exp_motif_seq, motif_base, motif_ind, position_based)
536 |
537 | pos_list_candidates=pos_list_candidates[(pos_list_candidates[:,0]>window*strides_per_base)\
538 | &(pos_list_candidates[:,0]<sequence_length-window*strides_per_base)] if len(pos_list_candidates)>0 else pos_list_candidates
539 |
540 | if len(pos_list_candidates)==0:
541 | total_unprocessed_reads.append(read_dict)
542 | continue
543 |
544 | if not move[0]:
545 | try:
546 | tags={x.split(':')[0]:x for x in read_dict.pop('tags')}
547 | start=int(tags['ts'].split(':')[-1])
548 | mv=tags['mv'].split(',')
549 |
550 | stride=int(mv[1])
551 | move_table=np.fromiter(mv[2:], dtype=np.int8)
552 | move=(stride, start, move_table)
553 | read_dict['tags']=[x for x in tags.values() if x[:2] not in ['mv', 'ts', 'ML', 'MM']]
554 | except KeyError:
555 | print('Read %s: no move table, stride, or signal start found' %read_dict['name'])
556 | total_unprocessed_reads.append(read_dict)
557 | continue
558 |
559 | base_seq=np.array([base_map[x] for x in fq])
560 | base_qual=10**((33-np.array([ord(x) for x in qual]))/10) #convert ASCII Phred scores to per-base error probabilities
561 | mean_qscore=-10*np.log10(np.mean(base_qual)) #mean read qscore from the average error probability
562 | base_qual=(1-base_qual)[:,np.newaxis] #probability of a correct basecall, used as a feature column
563 |
564 | if is_mapped and ref_available and not params['exclude_ref_features']:
565 | ref_seq=ref_seq_dict[ref_name][:,1][read_to_ref_pairs[:, 1]][::-1] if reverse else ref_seq_dict[ref_name][:,0][read_to_ref_pairs[:, 1]]
566 |
567 | else:
568 | ref_seq=dummy_ref_seq
569 |
570 |
571 | if full_signal:
572 | mat, indexes, segments=get_segment_events(signal, move, norm_type)
573 | segments_ends=np.concatenate([segments[1:],np.array([len(move_table)])])
574 | segment_ranges=np.vstack([segments, segments_ends]).T
575 | if seq_type=='rna':
576 | mat=np.flip(mat,axis=0)
577 | indexes=len(read_dict['seq'])-np.flip(indexes,axis=0)-1
578 | segment_ranges=np.flip(np.flip(len(move_table)-segment_ranges,axis=1),axis=0)
579 | else:
580 | mat=get_events(signal, move, norm_type)
581 | if seq_type=='rna':
582 | mat=np.flip(mat,axis=0)
583 | mat=np.hstack((mat, base_qual))
584 |
585 | try:
586 | c_idx=[True if x in pos_list_c else False for x in pos_list_candidates[:,0]]
587 | c_idx_count=np.vectorize(pos_list_c.get)(pos_list_candidates[c_idx,0])
588 | c_idx_count[1:]=c_idx_count[1:]-c_idx_count[:-1]-1
589 | MM='{}+{}?,'.format(motif_base,mod_symbol)+','.join(c_idx_count.astype(str))+';' #SAM MM tag: motif base and mod symbol, then delta-encoded skip counts
590 | total_c_idx.append(c_idx)
591 | total_MM_list.append(MM)
592 |
593 | except ValueError:
594 | total_c_idx.append([])
595 | total_MM_list.append(None)
596 |
597 |
598 | if full_signal:
599 | base_seq_full=base_seq[indexes]
600 | ref_seq_full=ref_seq[indexes]
601 | base_qual_full=base_qual[indexes]
602 | per_site_features=np.array([mat[segment_ranges[candidate[0]][0]-strides_per_base*window:segment_ranges[candidate[0]][0]+strides_per_base*(window+1)] for candidate in pos_list_candidates])
603 | per_site_base_qual=np.array([base_qual_full[segment_ranges[candidate[0]][0]-strides_per_base*window:segment_ranges[candidate[0]][0]+strides_per_base*(window+1)] for candidate in pos_list_candidates])
604 | per_site_indexes=np.array([(indexes==candidate[0])[segment_ranges[candidate[0]][0]-strides_per_base*window:segment_ranges[candidate[0]][0]+strides_per_base*(window+1)] for candidate in pos_list_candidates])
605 | per_site_features=np.dstack([per_site_features, per_site_indexes[:,:,np.newaxis], per_site_base_qual])
606 |
607 | per_site_base_seq=np.array([base_seq_full[segment_ranges[candidate[0]][0]-strides_per_base*window:segment_ranges[candidate[0]][0]+strides_per_base*(window+1)] for candidate in pos_list_candidates])
608 | per_site_ref_seq=np.array([ref_seq_full[segment_ranges[candidate[0]][0]-strides_per_base*window:segment_ranges[candidate[0]][0]+strides_per_base*(window+1)] for candidate in pos_list_candidates])
609 |
610 |
611 | else:
612 | per_site_features=np.array([mat[candidate[0]-window: candidate[0]+window+1] for candidate in pos_list_candidates])
613 | per_site_base_seq=np.array([base_seq[candidate[0]-window: candidate[0]+window+1] for candidate in pos_list_candidates])
614 | per_site_ref_seq=np.array([ref_seq[candidate[0]-window: candidate[0]+window+1] for candidate in pos_list_candidates])
615 |
616 | total_candidate_list.append(pos_list_candidates)
617 | total_feature_list.append(per_site_features)
618 | total_base_seq_list.append(per_site_base_seq)
619 | total_ref_seq_list.append(per_site_ref_seq)
620 |
621 | total_read_info.append((read_dict, [align_data[1],align_data[2],align_data[5], mean_qscore]))
622 |
623 | if len(total_read_info)>=reads_per_round:
624 | read_counts=np.cumsum([len(x) for x in total_feature_list])[:-1]
625 | features_list=np.vstack(total_feature_list)
626 | base_seq_list=np.vstack(total_base_seq_list)
627 | ref_seq_list=np.vstack(total_ref_seq_list)
628 |
629 | pred_list=[model(batch_x.to(dev), batch_base_seq.to(dev), batch_ref_seq.to(dev)).cpu().numpy() for batch_x, batch_base_seq, batch_ref_seq in generate_batches(features_list, base_seq_list, window, ref_seq=ref_seq_list, batch_size = chunk_size)]
630 |
631 |
632 | pred_list=np.vstack(pred_list)
633 | pred_list=np.split(pred_list.ravel(), read_counts)
634 | read_qual_list=[array.array('B',np.round(255*read_pred_list[c_idx]).astype(int)) for read_pred_list, c_idx in zip(pred_list, total_c_idx)]
635 |
636 | output_Q.put([True, total_read_info, total_candidate_list, total_MM_list, read_qual_list, pred_list])
637 | total_candidate_list, total_feature_list, total_base_seq_list, total_MM_list, total_read_info, total_c_idx=[], [], [], [], [], []
638 | total_ref_seq_list=[]
639 |
640 | if len(total_unprocessed_reads)>100:
641 | output_Q.put([False, total_unprocessed_reads])
642 | total_unprocessed_reads=[]
643 |
644 | except queue.Empty:
645 | pass
646 |
647 | if len(total_read_info)>0:
648 | read_counts=np.cumsum([len(x) for x in total_feature_list])[:-1]
649 | features_list=np.vstack(total_feature_list)
650 | base_seq_list=np.vstack(total_base_seq_list)
651 | ref_seq_list=np.vstack(total_ref_seq_list)
652 |
653 | pred_list=[model(batch_x.to(dev), batch_base_seq.to(dev), batch_ref_seq.to(dev)).cpu().numpy() for batch_x, batch_base_seq, batch_ref_seq in generate_batches(features_list, base_seq_list, window, ref_seq=ref_seq_list, batch_size = chunk_size)]
654 |
655 | pred_list=np.vstack(pred_list)
656 | pred_list=np.split(pred_list.ravel(), read_counts)
657 | read_qual_list=[array.array('B',np.round(255*read_pred_list[c_idx]).astype(int)) for read_pred_list, c_idx in zip(pred_list, total_c_idx)]
658 |
659 | output_Q.put([True, total_read_info, total_candidate_list, total_MM_list, read_qual_list, pred_list])
660 |
661 | if len(total_unprocessed_reads)>0:
662 | output_Q.put([False, total_unprocessed_reads])
663 |
664 | return
665 |
666 | def get_input(params, signal_Q, output_Q, input_event):
667 | chrom_list=params['chrom_list']
668 | length_cutoff=params['length_cutoff']
669 |
670 | skip_unmapped=params['skip_unmapped']
671 |
672 | bam=params['bam']
673 | bam_file=pysam.AlignmentFile(bam,'rb',check_sq=False)
674 |
675 | print('%s: Building BAM index.' %str(datetime.datetime.now()), flush=True)
676 | bam_index=pysam.IndexedReads(bam_file)
677 | bam_index.build()
678 | print('%s: Finished building BAM index.' %str(datetime.datetime.now()), flush=True)
679 |
680 | input_=params['input']
681 | signal_files= [input_] if os.path.isfile(input_) else Path(input_).rglob("*.%s" %params['file_type'])
682 |
683 | chunk=[]
684 | non_primary_reads=[]
685 | reads_per_chunk=100
686 |
687 | max_qsize=1000 if params['seq_type']=='rna' else 200
688 |
689 | if params['file_type']=='fast5':
690 | guppy_group=params['guppy_group']
691 | for filename in signal_files:
692 | with get_fast5_file(filename, mode="r") as f5:
693 | for read in f5.get_reads():
694 | if signal_Q.qsize()>200:
695 | time.sleep(20)
696 | #print('Pausing input due to INPUT queue size limit. Signal_qsize=%d' %(signal_Q.qsize()), flush=True)
697 | read_name=read.read_id
698 | non_primary_reads=[]
699 | try:
700 | read_iter=bam_index.find(read_name)
701 | for bam_read in read_iter:
702 | if (params['ref'] and bam_read.is_mapped and bam_read.reference_name not in chrom_list)\
703 | or bam_read.query_length < length_cutoff \
704 | or (bam_read.is_mapped==False and skip_unmapped==True):
705 | continue
706 |
707 | elif not (bam_read.is_supplementary or bam_read.is_secondary):
708 | read_dict=bam_read.to_dict()
709 | signal=read.get_raw_data()
710 |
711 | if params['fast5_move']:
712 | segment=read.get_analysis_attributes(guppy_group)['segmentation']
713 | start=read.get_analysis_attributes('%s/Summary/segmentation' %segment)['first_sample_template']
714 | stride=read.get_summary_data(guppy_group)['basecall_1d_template']['block_stride']
715 | move_table=read.get_analysis_dataset('%s/BaseCalled_template' %guppy_group, 'Move')
716 | move=(stride, start, move_table)
717 | else:
718 | move=(None,None,None)
719 |
720 | align_data=(bam_read.is_mapped if params['ref'] else False,
721 | bam_read.is_forward, bam_read.reference_name, bam_read.reference_start, bam_read.reference_end, bam_read.query_length)
722 | data=(signal, move, read_dict, align_data)
723 | chunk.append(data)
724 | if len(chunk)>=reads_per_chunk:
725 | signal_Q.put(chunk)
726 | chunk=[]
727 |
728 | else:
729 | pass
730 | #non_primary_reads.append(read_dict)
731 |
732 | if len(non_primary_reads)>0:
733 | output_Q.put([False, non_primary_reads])
734 |
735 | except KeyError:
736 | continue
737 |
738 | else:
739 | move=(None,None,None)
740 | for filename in signal_files:
741 | with p5.Reader(filename) as reader:
742 | for read in reader.reads():
743 | if signal_Q.qsize()>max_qsize:
744 | time.sleep(20)
745 | print('Pausing input due to INPUT queue size limit. Signal_qsize=%d' %(signal_Q.qsize()), flush=True)
746 |
747 | read_name=str(read.read_id)
748 | non_primary_reads=[]
749 | try:
750 | read_iter=bam_index.find(read_name)
751 | for bam_read in read_iter:
752 |
753 | if (params['ref'] and bam_read.is_mapped and bam_read.reference_name not in chrom_list)\
754 | or bam_read.query_length < length_cutoff \
755 | or (bam_read.is_mapped==False and skip_unmapped==True):
756 | continue
757 |
758 | elif not (bam_read.is_supplementary or bam_read.is_secondary):
759 | read_dict=bam_read.to_dict()
760 | signal=read.signal
761 | align_data=(bam_read.is_mapped if params['ref'] else False,
762 | bam_read.is_forward, bam_read.reference_name, bam_read.reference_start, bam_read.reference_end, bam_read.query_length)
763 | data=(signal, move, read_dict, align_data)
764 | chunk.append(data)
765 | if len(chunk)>=reads_per_chunk:
766 | signal_Q.put(chunk)
767 | chunk=[]
768 |
769 | else:
770 | pass
771 | #non_primary_reads.append(read_dict)
772 |
773 | if len(non_primary_reads)>0:
774 | output_Q.put([False, non_primary_reads])
775 |
776 | except KeyError:
777 | #print('Read:%s not found in BAM file' %read_name, flush=True)
778 | continue
779 |
780 | if len(chunk)>0:
781 | signal_Q.put(chunk)
782 | chunk=[]
783 |
784 | if len(non_primary_reads)>0:
785 | output_Q.put([False, non_primary_reads])
786 |
787 | input_event.set()
788 | return
789 |
790 | def call_manager(params):
791 | print('%s: Starting Per Read Methylation Detection.' %str(datetime.datetime.now()), flush=True)
792 | if params['dev']!='cpu':
793 | torch.multiprocessing.set_start_method('spawn')
794 |
795 | torch.set_num_threads(1)
796 |
797 | pmanager = mp.Manager()
798 |
799 | bam=params['bam']
800 | bam_file=pysam.AlignmentFile(bam,'rb',check_sq=False)
801 | header_dict=bam_file.header.to_dict()
802 |
803 | print('%s: Getting motif positions from the reference.' %str(datetime.datetime.now()), flush=True)
804 |
805 | ref_seq_dict={}
806 | ref_pos_dict={}
807 |
808 | mod_positions_list=get_pos(params['mod_positions']) if params['mod_positions'] else None
809 | position_based=True if params['mod_positions'] else False
810 |
811 | if position_based:
812 | params['chrom_list']=[x for x in params['chrom_list'] if x in mod_positions_list.keys()]
813 |
814 | _=get_ref_to_num('ACGT')
815 |
816 | if params['ref'] and len(params['chrom_list'])>0:
817 | with mp.Pool(processes=params['threads']) as pool:
818 | res=pool.map(get_ref_info, zip(repeat(params), params['chrom_list']))
819 | for r in res:
820 | chrom, seq_array, fwd_pos_array, rev_pos_array=r
821 | ref_seq_dict[chrom]=seq_array
822 |
823 | if position_based:
824 | ref_pos_dict[chrom]=(np.array(sorted(list(set(fwd_pos_array)&set(mod_positions_list[chrom][0])))),
825 | np.array(sorted(list(set(rev_pos_array)&set(mod_positions_list[chrom][1])))))
826 |
827 | else:
828 | ref_pos_dict[chrom]=(fwd_pos_array, rev_pos_array)
829 |
830 |
831 | params['position_based']=True if position_based or params['reference_motif_only'] else False
832 |
833 | res=None
834 |
835 | print('%s: Finished getting motif positions from the reference.' %str(datetime.datetime.now()), flush=True)
836 |
837 | signal_Q = pmanager.Queue()
838 | output_Q = pmanager.Queue()
839 | methylation_event=pmanager.Event()
840 | input_event=pmanager.Event()
841 |
842 | handlers = []
843 |
844 | input_process = mp.Process(target=get_input, args=(params, signal_Q, output_Q, input_event))
845 | input_process.start()
846 |
847 | if params['motif_seq']=='CG':
848 | output_process=mp.Process(target=get_cpg_output, args=(params, output_Q, methylation_event, header_dict, ref_pos_dict));
849 | else:
850 | output_process=mp.Process(target=get_output, args=(params, output_Q, methylation_event, header_dict, ref_pos_dict));
851 | output_process.start();
852 |
853 | for hid in range(max(1,params['threads']-1)):
854 | p = mp.Process(target=process, args=(params, ref_pos_dict, signal_Q, output_Q, input_event, ref_seq_dict));
855 | p.start();
856 | handlers.append(p);
857 |
858 | input_process.join()
859 | print('%s: Reading inputs complete.' %str(datetime.datetime.now()), flush=True)
860 |
861 | for job in handlers:
862 | job.join()
863 |
864 | methylation_event.set()
865 |
866 | print('%s: Model predictions complete. Wrapping up output.' %str(datetime.datetime.now()), flush=True)
867 |
868 | output_process.join()
869 |
870 | return
871 |
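872 | # Process layout (a summary of the functions above):
873 | #   get_input: reads raw signal (FAST5/POD5) and matching BAM records, batching reads onto signal_Q
874 | #   process: worker processes that featurize candidate sites and run the model, pushing predictions to output_Q
875 | #   get_output/get_cpg_output: a single writer that emits the tagged BAM plus per-read and per-site files
876 | #   call_manager: creates the queues and events, then starts and joins all processes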
--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
1 | import torch, math
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn import TransformerEncoder, TransformerEncoderLayer
5 | from torch import Tensor
6 |
7 | class OneHotEncode(nn.Module):
8 | def __init__(self, num_classes: int):
9 | super().__init__()
10 | self.num_classes=num_classes
11 | def forward(self, x: Tensor) -> Tensor:
12 | return F.one_hot(x, self.num_classes)
13 |
14 | class RefandReadEmbed(nn.Module):
15 | def __init__(self, embedding_dim, embedding_type):
16 | super().__init__()
17 |
18 | self.embedding_depth=0
19 |
20 | if embedding_type=='one_hot':
21 | self.read_emb=OneHotEncode(4)
22 | self.embedding_depth+=4
23 |
24 | self.ref_emb=OneHotEncode(5)
25 | self.embedding_depth+=5
26 |
27 | elif embedding_type=='learnable':
28 | self.read_emb=nn.Embedding(4, embedding_dim)
29 | self.embedding_depth+=embedding_dim
30 |
31 | self.ref_emb=nn.Embedding(5, embedding_dim)
32 | self.embedding_depth+=embedding_dim
33 |
34 | def forward(self, batch_base_seq, batch_ref_seq):
35 | batch_base_seq_emb=self.read_emb(batch_base_seq)
36 | batch_ref_seq_emb=self.ref_emb(batch_ref_seq)
37 |
38 | return torch.cat((batch_base_seq_emb, batch_ref_seq_emb), 2)
39 |
40 | class ReadEmbed(nn.Module):
41 | def __init__(self, embedding_dim, embedding_type):
42 | super().__init__()
43 |
44 | self.embedding_depth=0
45 |
46 | if embedding_type=='one_hot':
47 | self.read_emb=OneHotEncode(4)
48 | self.embedding_depth+=4
49 |
50 | elif embedding_type=='learnable':
51 | self.read_emb=nn.Embedding(4, embedding_dim)
52 | self.embedding_depth+=embedding_dim
53 |
54 |
55 | def forward(self, batch_base_seq, batch_ref_seq):
56 | batch_base_seq_emb=self.read_emb(batch_base_seq)
57 |
58 | return batch_base_seq_emb
59 |
60 | class SeqEmbed(nn.Module):
61 | def __init__(self, embedding_dim, embedding_type, include_ref):
62 | super().__init__()
63 |
64 | self.embedding_depth=0
65 |
66 | if include_ref:
67 | self.seq_emb=RefandReadEmbed(embedding_dim, embedding_type)
68 |
69 | else:
70 | self.seq_emb=ReadEmbed(embedding_dim, embedding_type)
71 |
72 | self.embedding_depth=self.seq_emb.embedding_depth
73 |
74 | def forward(self, batch_base_seq, batch_ref_seq):
75 | return self.seq_emb(batch_base_seq, batch_ref_seq)
76 |
77 | class PositionalEncoding(nn.Module):
78 | def __init__(self, pe_dim: int, max_len: int):
79 | super().__init__()
80 | position = torch.arange(max_len).unsqueeze(1)
81 | div_term = torch.exp(torch.arange(0, pe_dim, 2) * (-math.log(pe_dim) / (pe_dim)))
82 | pe = torch.zeros(1, max_len, pe_dim)
83 | pe[0,:, 0::2] = torch.sin(position * div_term)
84 | pe[0,:, 1::2] = torch.cos(position * div_term)
85 | self.register_buffer('pe', pe)
86 |
87 | def forward(self, x: Tensor) -> Tensor:
88 | x_pos=torch.Tensor.repeat(self.pe,(x.size(0),1,1))
89 | x = torch.cat((x, x_pos),2)
90 | return x
91 |
92 | class PositionalEmbedding(nn.Module):
93 | def __init__(self, pe_dim: int, max_len: int):
94 | super().__init__()
95 | pos=torch.arange(max_len)
96 | self.register_buffer('pos', pos)
97 | self.pe=nn.Embedding(max_len, pe_dim)
98 |
99 | def forward(self, x: Tensor) -> Tensor:
100 | x_pos=self.pe(self.pos)
101 | x_pos=torch.Tensor.repeat(x_pos,(x.size(0),1,1))
102 | x = torch.cat((x, x_pos),2)
103 | return x
104 |
105 | class PositionalParameter(nn.Module):
106 | def __init__(self, pe_dim: int, max_len: int):
107 | super().__init__()
108 |
109 | self.pe=torch.nn.Parameter(torch.randn(max_len, pe_dim))
110 |
111 | def forward(self, x: Tensor) -> Tensor:
112 | x_pos=torch.Tensor.repeat(self.pe,(x.size(0),1,1))
113 | x = torch.cat((x, x_pos),2)
114 | return x
115 |
116 | class ClassifierMiddle(nn.Module):
117 | def __init__(self, in_dim: int, num_fc: int, model_len: int):
118 | super().__init__()
119 | self.mid = model_len//2
120 | self.fc = nn.Linear(in_dim, num_fc)
121 | self.out = nn.Linear(num_fc,1)
122 |
123 | def forward(self, x):
124 | x = F.relu(self.fc(x[:,self.mid, :]))
125 | x=self.out(x)
126 | return x
127 |
128 | class ClassifierAll(nn.Module):
129 | def __init__(self, in_dim: int, num_fc: int):
130 | super().__init__()
131 | self.fc = nn.Linear(in_dim, num_fc)
132 | self.out = nn.Linear(num_fc,1)
133 |
134 | def forward(self, x):
135 | x = torch.flatten(x, 1)
136 | x = F.relu(self.fc(x))
137 | x=self.out(x)
138 | return x
139 |
140 | class BiLSTM(nn.Module):
141 | def __init__(self, model_dims, num_layers, dim_feedforward, num_fc, embedding_dim, embedding_type, include_ref, fc_type):
142 | super(BiLSTM, self).__init__()
143 |
144 | self.emb=SeqEmbed(embedding_dim, embedding_type, include_ref)
145 | self.model_len=model_dims[0]
146 | self.model_depth=model_dims[1]+self.emb.embedding_depth
147 |
148 | self.bilstm = nn.LSTM(input_size=self.model_depth, hidden_size=dim_feedforward, num_layers=num_layers, bidirectional=True, batch_first = True)
149 |
150 | if fc_type=='middle':
151 | self.classifier=ClassifierMiddle(in_dim=dim_feedforward*2, num_fc=num_fc, model_len=self.model_len)
152 |
153 | else:
154 | self.classifier=ClassifierAll(in_dim=self.model_len*dim_feedforward*2, num_fc=num_fc)
155 |
156 | def forward(self, batch_x, batch_base_seq, batch_ref_seq):
157 | seq_emb=self.emb(batch_base_seq, batch_ref_seq)
158 | x=torch.cat((batch_x, seq_emb), 2)
159 | x, _=self.bilstm(x)
160 | x = self.classifier(x)
161 | x=torch.nn.functional.sigmoid(x)
162 | return x
163 |
164 | class TransformerModel(nn.Module):
165 | def __init__(self, model_dims, num_layers, dim_feedforward, num_fc, embedding_dim, embedding_type, include_ref, pe_dim, nhead, pe_type, fc_type):
166 | super(TransformerModel, self).__init__()
167 |
168 | self.emb=SeqEmbed(embedding_dim, embedding_type, include_ref)
169 | self.model_len=model_dims[0]
170 |
171 | if pe_type=='fixed':
172 | self.pe_block=PositionalEncoding(pe_dim=pe_dim, max_len=self.model_len)
173 |
174 | elif pe_type=='embedding':
175 | self.pe_block=PositionalEmbedding(pe_dim=pe_dim, max_len=self.model_len)
176 |
177 | elif pe_type=='parameter':
178 | self.pe_block=PositionalParameter(pe_dim=pe_dim, max_len=self.model_len)
179 |
180 | self.model_depth=model_dims[1]+self.emb.embedding_depth+pe_dim
181 | self.pad_length=math.ceil(self.model_depth/nhead)*nhead-self.model_depth
182 | pad=torch.zeros(1,self.model_len, self.pad_length)
183 | self.register_buffer('pad', pad)
184 | self.model_depth+=self.pad_length
185 |
186 | encoder_layer = nn.TransformerEncoderLayer(d_model=self.model_depth, nhead=nhead, dim_feedforward=dim_feedforward, dropout=0, batch_first = True)
187 | self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
188 |
189 | if fc_type=='middle':
190 | self.classifier=ClassifierMiddle(in_dim=self.model_depth, num_fc=num_fc, model_len=self.model_len)
191 |
192 | else:
193 | self.classifier=ClassifierAll(in_dim=self.model_len*self.model_depth, num_fc=num_fc)
194 |
195 | def forward(self, batch_x, batch_base_seq, batch_ref_seq):
196 | seq_emb=self.emb(batch_base_seq, batch_ref_seq)
197 | x=torch.cat((batch_x, seq_emb), 2)
198 | x=self.pe_block(x)
199 | x_pad=torch.Tensor.repeat(self.pad,(x.size(0),1,1))
200 | x = torch.cat((x, x_pad),2)
201 |
202 | x=self.transformer_encoder(x)
203 | x = self.classifier(x)
204 | x=torch.nn.functional.sigmoid(x)
205 | return x
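206 |
207 | # Note: both BiLSTM and TransformerModel take (batch_x, batch_base_seq, batch_ref_seq), i.e. per-base signal
208 | # features plus integer-encoded read and reference bases, and return a sigmoid score in [0, 1] that is
209 | # interpreted as the per-candidate modification probability.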
--------------------------------------------------------------------------------
/src/models/bilstm.cfg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WGLab/DeepMod2/d674273150c03b57dc6cd0f38cf45966595aecf7/src/models/bilstm.cfg
--------------------------------------------------------------------------------
/src/models/bilstm/R10.4.1_4kHz_v3.5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WGLab/DeepMod2/d674273150c03b57dc6cd0f38cf45966595aecf7/src/models/bilstm/R10.4.1_4kHz_v3.5
--------------------------------------------------------------------------------
/src/models/bilstm/R10.4.1_4kHz_v4.1:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WGLab/DeepMod2/d674273150c03b57dc6cd0f38cf45966595aecf7/src/models/bilstm/R10.4.1_4kHz_v4.1
--------------------------------------------------------------------------------
/src/models/bilstm/R10.4.1_5kHz_v4.3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WGLab/DeepMod2/d674273150c03b57dc6cd0f38cf45966595aecf7/src/models/bilstm/R10.4.1_5kHz_v4.3
--------------------------------------------------------------------------------
/src/models/bilstm/R10.4.1_5kHz_v5.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WGLab/DeepMod2/d674273150c03b57dc6cd0f38cf45966595aecf7/src/models/bilstm/R10.4.1_5kHz_v5.0
--------------------------------------------------------------------------------
/src/models/bilstm/R9.4.1:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WGLab/DeepMod2/d674273150c03b57dc6cd0f38cf45966595aecf7/src/models/bilstm/R9.4.1
--------------------------------------------------------------------------------
/src/models/transformer.cfg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WGLab/DeepMod2/d674273150c03b57dc6cd0f38cf45966595aecf7/src/models/transformer.cfg
--------------------------------------------------------------------------------
/src/models/transformer/R10.4.1_4kHz_v3.5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WGLab/DeepMod2/d674273150c03b57dc6cd0f38cf45966595aecf7/src/models/transformer/R10.4.1_4kHz_v3.5
--------------------------------------------------------------------------------
/src/models/transformer/R10.4.1_4kHz_v4.1:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WGLab/DeepMod2/d674273150c03b57dc6cd0f38cf45966595aecf7/src/models/transformer/R10.4.1_4kHz_v4.1
--------------------------------------------------------------------------------
/src/models/transformer/R10.4.1_5kHz_v4.3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WGLab/DeepMod2/d674273150c03b57dc6cd0f38cf45966595aecf7/src/models/transformer/R10.4.1_5kHz_v4.3
--------------------------------------------------------------------------------
/src/models/transformer/R10.4.1_5kHz_v5.0:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WGLab/DeepMod2/d674273150c03b57dc6cd0f38cf45966595aecf7/src/models/transformer/R10.4.1_5kHz_v5.0
--------------------------------------------------------------------------------
/src/models/transformer/R9.4.1:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WGLab/DeepMod2/d674273150c03b57dc6cd0f38cf45966595aecf7/src/models/transformer/R9.4.1
--------------------------------------------------------------------------------
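The checkpoints above follow one naming scheme: architecture, flowcell, sampling rate, and the basecaller model version used on the training data. A small hypothetical helper (not part of DeepMod2) shows how a `model_dict` key for the `--model` option is composed:

```python
# Illustrative only: compose a model_dict key from run properties.
def pick_model(arch, flowcell, khz=None, bc_version=None):
    if flowcell == 'r9.4.1':  # R9.4.1 checkpoints are not split by sampling rate
        return f'{arch}_r9.4.1'
    return f'{arch}_{flowcell}_{khz}khz_v{bc_version}'

print(pick_model('bilstm', 'r10.4.1', khz=5, bc_version='5.0'))
# bilstm_r10.4.1_5khz_v5.0
```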
/src/utils.py:
--------------------------------------------------------------------------------
1 | from subprocess import PIPE, Popen
2 | import os, shutil, pysam, sys, datetime, re, pickle
3 | import numpy as np
4 | from numba import jit
5 | import torch
6 | from itertools import repeat
8 | from .models import *
9 | import torch.nn.utils.prune as prune
10 | from tqdm import tqdm
11 |
12 | model_dict={
13 |
14 | 'bilstm_r9.4.1' : { 'path' : 'models/bilstm/R9.4.1',
15 | 'help':'BiLSTM model trained on chr2-21 of HG002, HG003 and HG004 R9.4.1 flowcells.',
16 | 'model_config_path':'models/bilstm.cfg'},
17 |
18 | 'bilstm_r10.4.1_4khz_v3.5': { 'path' : 'models/bilstm/R10.4.1_4kHz_v3.5',
19 | 'help': 'BiLSTM model trained on chr2-21 of HG002, HG003 and HG004 R10.4.1 flowcells with 4kHz sampling, with basecalling performed by v3.5 Guppy/Dorado basecaller model.',
20 | 'model_config_path':'models/bilstm.cfg'},
21 |
22 | 'bilstm_r10.4.1_4khz_v4.1': { 'path' : 'models/bilstm/R10.4.1_4kHz_v4.1',
23 | 'help': 'BiLSTM model trained on chr2-21 of HG002, HG003 and HG004 R10.4.1 flowcells with 4kHz sampling, with basecalling performed by v4.1 Guppy/Dorado basecaller model.',
24 | 'model_config_path':'models/bilstm.cfg'},
25 |
26 | 'bilstm_r10.4.1_5khz_v4.3': { 'path' : 'models/bilstm/R10.4.1_5kHz_v4.3',
27 | 'help': 'BiLSTM model trained on chr2-21 of HG002, HG003 and HG004 R10.4.1 flowcells with 5kHz sampling, with basecalling performed by v4.3 Guppy/Dorado basecaller model.',
28 | 'model_config_path':'models/bilstm.cfg'},
29 |
30 | 'bilstm_r10.4.1_5khz_v5.0': { 'path' : 'models/bilstm/R10.4.1_5kHz_v5.0',
31 | 'help': 'BiLSTM model trained on chr2-21 of HG002, HG003 and HG004 R10.4.1 flowcells with 5kHz sampling, with basecalling performed by v5.0 Guppy/Dorado basecaller model.',
32 | 'model_config_path':'models/bilstm.cfg'},
33 |
34 |
35 | 'transformer_r9.4.1' : { 'path' : 'models/transformer/R9.4.1',
36 | 'help':'Transformer model trained on chr2-21 of HG002, HG003 and HG004 R9.4.1 flowcells.',
37 | 'model_config_path':'models/transformer.cfg'},
38 |
39 | 'transformer_r10.4.1_4khz_v3.5': { 'path' : 'models/transformer/R10.4.1_4kHz_v3.5',
40 |                                   'help': 'Transformer model trained on chr2-21 of HG002, HG003 and HG004 R10.4.1 flowcells with 4kHz sampling, with basecalling performed by v3.5 Guppy/Dorado basecaller model.',
41 | 'model_config_path':'models/transformer.cfg'},
42 |
43 | 'transformer_r10.4.1_4khz_v4.1': { 'path' : 'models/transformer/R10.4.1_4kHz_v4.1',
44 |                                   'help': 'Transformer model trained on chr2-21 of HG002, HG003 and HG004 R10.4.1 flowcells with 4kHz sampling, with basecalling performed by v4.1 Guppy/Dorado basecaller model.',
45 | 'model_config_path':'models/transformer.cfg'},
46 |
47 | 'transformer_r10.4.1_5khz_v4.3': { 'path' : 'models/transformer/R10.4.1_5kHz_v4.3',
48 |                                   'help': 'Transformer model trained on chr2-21 of HG002, HG003 and HG004 R10.4.1 flowcells with 5kHz sampling, with basecalling performed by v4.3 Guppy/Dorado basecaller model.',
49 | 'model_config_path':'models/transformer.cfg'},
50 |
51 | 'transformer_r10.4.1_5khz_v5.0': { 'path' : 'models/transformer/R10.4.1_5kHz_v5.0',
52 |                                   'help': 'Transformer model trained on chr2-21 of HG002, HG003 and HG004 R10.4.1 flowcells with 5kHz sampling, with basecalling performed by v5.0 Guppy/Dorado basecaller model.',
53 | 'model_config_path':'models/transformer.cfg'},
54 | }
55 |
56 | comp_base_map={'A':'T','T':'A','C':'G','G':'C','[':']', ']':'['}
57 |
58 | def revcomp(s):
59 | return ''.join(comp_base_map[x] for x in s[::-1])
60 |
61 | def get_model_help():
62 | for n,model in enumerate(model_dict):
63 | print('-'*30)
64 | print('%d) Model Name: %s' %(n+1, model))
65 | print('Details: %s\n' %model_dict[model]['help'])
66 |
67 | def get_model(params):
68 | model_name=params['model']
69 | model_config_path=None
70 | model_path=None
71 |
72 | if model_name in model_dict:
73 | dirname = os.path.dirname(__file__)
74 | model_info=model_dict[model_name]
75 | model_config_path = os.path.join(dirname, model_info['model_config_path'])
76 | model_path = os.path.join(dirname, model_info['path'])
77 |
78 | else:
79 | try:
80 | model_config_path = model_name.split(',')[0]
81 | model_path = model_name.split(',')[1]
82 |
83 | except IndexError:
84 | print('Incorrect model specified')
85 | sys.exit(2)
86 |
87 | with open(model_config_path, 'rb') as handle:
88 | model_config = pickle.load(handle)
89 |
90 | if model_config['model_type']=='bilstm':
91 | model = BiLSTM(model_dims=model_config['model_dims'], num_layers=model_config['num_layers'], \
92 | dim_feedforward=model_config['dim_feedforward'], \
93 | num_fc=model_config['num_fc'], embedding_dim=model_config['embedding_dim'], \
94 | embedding_type=model_config['embedding_type'], include_ref=model_config['include_ref'], \
95 | fc_type=model_config['fc_type']);
96 |
97 | checkpoint = torch.load(model_path, map_location ='cpu')
98 | model.load_state_dict(checkpoint['model_state_dict'])
99 |
100 | if not params['disable_pruning']:
101 | module=model.classifier.fc
102 | prune.l1_unstructured(module, name="weight", amount=0.95)
103 | prune.remove(module, 'weight')
104 |
105 | return model, model_config
106 |
107 | elif model_config['model_type']=='transformer':
108 | model = TransformerModel(model_dims=model_config['model_dims'], num_layers=model_config['num_layers'], \
109 | dim_feedforward=model_config['dim_feedforward'], \
110 | num_fc=model_config['num_fc'], embedding_dim=model_config['embedding_dim'], \
111 | embedding_type=model_config['embedding_type'], include_ref=model_config['include_ref'],\
112 | pe_dim=model_config['pe_dim'], nhead=model_config['nhead'], \
113 | pe_type=model_config['pe_type'], fc_type=model_config['fc_type']);
114 |
115 | checkpoint = torch.load(model_path, map_location ='cpu')
116 | model.load_state_dict(checkpoint['model_state_dict'])
117 |
118 | if not params['disable_pruning']:
119 | module=model.classifier.fc
120 | prune.l1_unstructured(module, name="weight", amount=0.5)
121 | prune.remove(module, 'weight')
122 | for l in model.transformer_encoder.layers:
123 | module=l.linear1
124 | prune.l1_unstructured(module, name="weight", amount=0.25)
125 | prune.remove(module, 'weight')
126 | module=l.linear2
127 | prune.l1_unstructured(module, name="weight", amount=0.25)
128 | prune.remove(module, 'weight')
129 | module=l.self_attn.out_proj
130 | prune.l1_unstructured(module, name="weight", amount=0.25)
131 | prune.remove(module, 'weight')
132 |
133 | return model, model_config
134 |
135 | else:
136 | print('Model: %s not found.' %params['model'], flush=True)
137 | sys.exit(2)
138 |
139 | def generate_batches(features, base_seq, window, ref_seq=None, batch_size=512):
140 |     if ref_seq is None or len(ref_seq)==0:
141 | ref_seq=(4+torch.zeros(features.shape[0], 2*window+1)).type(torch.LongTensor)
142 | else:
143 | ref_seq=torch.Tensor(ref_seq).type(torch.LongTensor)
144 |
145 | features=torch.Tensor(features)
146 | base_seq=torch.Tensor(base_seq).type(torch.LongTensor)
147 |
148 | for local_index in range(0, features.shape[0], batch_size):
149 | batch_x=features[local_index:(local_index + batch_size)]
150 | batch_base_seq=base_seq[local_index:(local_index + batch_size)]
151 | batch_ref_seq=ref_seq[local_index:(local_index + batch_size)]
152 |
153 | yield batch_x, batch_base_seq, batch_ref_seq
154 |
155 | @jit(nopython=True)
156 | def get_aligned_pairs(cigar_tuples, ref_start):
157 | alen=np.sum(cigar_tuples[:,0])
158 | pairs=np.zeros((alen,2)).astype(np.int32)
159 |
160 | i=0
161 | ref_cord=ref_start-1
162 | read_cord=-1
163 | pair_cord=0
164 | for i in range(len(cigar_tuples)):
165 | len_op, op= cigar_tuples[i,0], cigar_tuples[i,1]
166 | if op==0:
167 | for k in range(len_op):
168 | ref_cord+=1
169 | read_cord+=1
170 |
171 | pairs[pair_cord,0]=read_cord
172 | pairs[pair_cord,1]=ref_cord
173 | pair_cord+=1
174 |
175 | elif op==2:
176 | for k in range(len_op):
177 | read_cord+=1
178 | pairs[pair_cord,0]=read_cord
179 | pairs[pair_cord,1]=-1
180 | pair_cord+=1
181 |
182 | elif op==1:
183 | for k in range(len_op):
184 | ref_cord+=1
185 | pairs[pair_cord,0]=-1
186 | pairs[pair_cord,1]=ref_cord
187 | pair_cord+=1
188 | return pairs
189 |
190 | @jit(nopython=True)
191 | def get_ref_to_num(x):
192 | b=np.full((len(x)+1,2),fill_value=0,dtype=np.int8)
193 |
194 | for i,l in enumerate(x):
195 | if l=='A':
196 | b[i,0]=0
197 | b[i,1]=3
198 |
199 | elif l=='T':
200 | b[i,0]=3
201 | b[i,1]=0
202 |
203 | elif l=='C':
204 | b[i,0]=1
205 | b[i,1]=2
206 |
207 | elif l=='G':
208 | b[i,0]=2
209 | b[i,1]=1
210 |
211 | else:
212 | b[i,0]=4
213 | b[i,1]=4
214 |
215 | b[-1,0]=4
216 | b[-1,1]=4
217 |
218 | return b
219 |
220 | def get_pos(path):
221 | labelled_pos_list={}
222 | strand_map={'+':0, '-':1}
223 |
224 | with open(path) as file:
225 | for line in file:
226 | line=line.rstrip('\n').split('\t')
227 | if line[0] not in labelled_pos_list:
228 | labelled_pos_list[line[0]]={0:[], 1:[]}
229 |
230 | labelled_pos_list[line[0]][strand_map[line[2]]].append(int(line[1]))
231 |
232 | return labelled_pos_list
233 |
234 | def motif_check(motif):
235 | nt_dict={'R': 'GA',
236 | 'Y': 'CT',
237 | 'K': 'GT',
238 | 'M': 'AC',
239 | 'S': 'GC',
240 | 'W': 'AT',
241 | 'B': 'GTC',
242 | 'D': 'GAT',
243 | 'H': 'ACT',
244 | 'V': 'GCA',
245 | 'N': 'AGCT'}
246 |
247 | valid_alphabet=set(nt_dict.keys()).union({'A', 'C', 'G', 'T'})
248 |
249 |     motif_seq, exp_motif_seq, final_motif_ind, valid = None, None, None, False
250 |
251 | if len(motif)<2:
252 | print('--motif not specified correctly. You need to specify a motif and at least one index',flush=True)
253 | return motif_seq, exp_motif_seq, final_motif_ind, valid
254 |
255 | elif len(set(motif[0])-valid_alphabet)>0:
256 | print('--motif not specified correctly. Motif should only consist of the following extended nucleotide letters: {}'.format(','.join(valid_alphabet)),flush=True)
257 | return motif_seq, exp_motif_seq, final_motif_ind, valid
258 |
259 |     elif not all(a.isnumeric() for a in motif[1:]):
260 |         print('--motif not specified correctly. Motif indices should be integers separated by whitespace and should come after the motif sequence.',flush=True)
261 | return motif_seq, exp_motif_seq, final_motif_ind, valid
262 |
263 | else:
264 | motif_seq=motif[0]
265 | motif_ind=[int(x) for x in motif[1:]]
266 |
267 | if len(set(motif_seq[x] for x in motif_ind))!=1 or len(set(motif_seq[x] for x in motif_ind)-set('ACGT'))>0:
268 |             print('Base of interest should be the same for all indices and must be one of A, C, G or T.', flush=True)
269 | return motif_seq, exp_motif_seq, final_motif_ind, valid
270 |
271 | else:
272 | exp_motif_seq=motif_seq
273 | for nt in nt_dict:
274 | if nt in exp_motif_seq:
275 | exp_motif_seq=exp_motif_seq.replace(nt, '[{}]'.format(nt_dict[nt]))
276 | return motif_seq, exp_motif_seq, motif_ind, True
277 |
278 | def get_ref_info(args):
279 | params, chrom=args
280 | motif_seq, motif_ind=params['motif_seq'], params['motif_ind']
281 | exp_motif_seq=params['exp_motif_seq']
282 | ref_fasta=pysam.FastaFile(params['ref'])
283 | seq=ref_fasta.fetch(chrom).upper()
284 | seq_array=get_ref_to_num(seq)
285 |
286 |
287 | fwd_motif_anchor=np.array([m.start(0) for m in re.finditer(r'(?={})'.format(exp_motif_seq), seq)])
288 | rev_motif_anchor=np.array([m.start(0) for m in re.finditer(r'(?={})'.format(revcomp(exp_motif_seq)), seq)])
289 |
290 | fwd_pos_array=np.array(sorted(list(set.union(*[set(fwd_motif_anchor+i) for i in motif_ind])))).astype(int)
291 | rev_pos_array=np.array(sorted(list(set.union(*[set(rev_motif_anchor+len(motif_seq)-1-i) for i in motif_ind])))).astype(int)
292 |
293 | return chrom, seq_array, fwd_pos_array, rev_pos_array
294 |
295 | def get_stats_string_cpg(chrom, pos, is_ref_cpg, cpg):
296 | unphased_rev_unmod, unphased_rev_mod, unphased_fwd_unmod, unphased_fwd_mod=cpg[0:4]
297 | phase1_rev_unmod, phase1_rev_mod, phase1_fwd_unmod, phase1_fwd_mod=cpg[4:8]
298 | phase2_rev_unmod, phase2_rev_mod, phase2_fwd_unmod, phase2_fwd_mod=cpg[8:12]
299 |
300 | fwd_mod=unphased_fwd_mod+phase1_fwd_mod+phase2_fwd_mod
301 | fwd_unmod=unphased_fwd_unmod+phase1_fwd_unmod+phase2_fwd_unmod
302 | fwd_total_stats=[fwd_mod+fwd_unmod,fwd_mod,fwd_unmod,fwd_mod/(fwd_mod+fwd_unmod) if fwd_mod+fwd_unmod>0 else 0]
303 | fwd_phase1_stats=[phase1_fwd_mod+phase1_fwd_unmod, phase1_fwd_mod, phase1_fwd_unmod, phase1_fwd_mod/(phase1_fwd_mod+phase1_fwd_unmod) if phase1_fwd_mod+phase1_fwd_unmod>0 else 0]
304 | fwd_phase2_stats=[phase2_fwd_mod+phase2_fwd_unmod, phase2_fwd_mod, phase2_fwd_unmod, phase2_fwd_mod/(phase2_fwd_mod+phase2_fwd_unmod) if phase2_fwd_mod+phase2_fwd_unmod>0 else 0]
305 |
306 | fwd_str='{}\t{}\t{}\t+\t{}\t'.format(chrom, pos, pos+1, is_ref_cpg)+'{}\t{}\t{}\t{:.4f}\t'.format(*fwd_total_stats) + '{}\t{}\t{}\t{:.4f}\t'.format(*fwd_phase1_stats) + '{}\t{}\t{}\t{:.4f}\n'.format(*fwd_phase2_stats)
307 |
308 |
309 | rev_mod=unphased_rev_mod+phase1_rev_mod+phase2_rev_mod
310 | rev_unmod=unphased_rev_unmod+phase1_rev_unmod+phase2_rev_unmod
311 | rev_total_stats=[rev_mod+rev_unmod,rev_mod,rev_unmod,rev_mod/(rev_mod+rev_unmod) if rev_mod+rev_unmod>0 else 0]
312 | rev_phase1_stats=[phase1_rev_mod+phase1_rev_unmod, phase1_rev_mod, phase1_rev_unmod, phase1_rev_mod/(phase1_rev_mod+phase1_rev_unmod) if phase1_rev_mod+phase1_rev_unmod>0 else 0]
313 | rev_phase2_stats=[phase2_rev_mod+phase2_rev_unmod, phase2_rev_mod, phase2_rev_unmod, phase2_rev_mod/(phase2_rev_mod+phase2_rev_unmod) if phase2_rev_mod+phase2_rev_unmod>0 else 0]
314 |
315 | rev_str='{}\t{}\t{}\t-\t{}\t'.format(chrom, pos+1, pos+2, is_ref_cpg)+'{}\t{}\t{}\t{:.4f}\t'.format(*rev_total_stats) + '{}\t{}\t{}\t{:.4f}\t'.format(*rev_phase1_stats) + '{}\t{}\t{}\t{:.4f}\n'.format(*rev_phase2_stats)
316 |
317 |
318 | agg_total_stats=[fwd_total_stats[0]+rev_total_stats[0], fwd_total_stats[1]+rev_total_stats[1], fwd_total_stats[2]+rev_total_stats[2], (fwd_total_stats[1]+rev_total_stats[1])/(fwd_total_stats[0]+rev_total_stats[0]) if fwd_total_stats[0]+rev_total_stats[0]>0 else 0]
319 |
320 | agg_phase1_stats=[fwd_phase1_stats[0]+rev_phase1_stats[0], fwd_phase1_stats[1]+rev_phase1_stats[1], fwd_phase1_stats[2]+rev_phase1_stats[2], (fwd_phase1_stats[1]+rev_phase1_stats[1])/(fwd_phase1_stats[0]+rev_phase1_stats[0]) if fwd_phase1_stats[0]+rev_phase1_stats[0]>0 else 0]
321 |
322 | agg_phase2_stats=[fwd_phase2_stats[0]+rev_phase2_stats[0], fwd_phase2_stats[1]+rev_phase2_stats[1], fwd_phase2_stats[2]+rev_phase2_stats[2], (fwd_phase2_stats[1]+rev_phase2_stats[1])/(fwd_phase2_stats[0]+rev_phase2_stats[0]) if fwd_phase2_stats[0]+rev_phase2_stats[0]>0 else 0]
323 |
324 | agg_str='{}\t{}\t{}\t{}\t'.format(chrom, pos, pos+2, is_ref_cpg)+'{}\t{}\t{}\t{:.4f}\t'.format(*agg_total_stats) + '{}\t{}\t{}\t{:.4f}\t'.format(*agg_phase1_stats) + '{}\t{}\t{}\t{:.4f}\n'.format(*agg_phase2_stats)
325 |
326 | return [(agg_total_stats[0], agg_str),(fwd_total_stats[0], fwd_str),(rev_total_stats[0], rev_str)]
327 |
328 |
329 | def get_stats_string(chrom, pos, strand, mod_call):
330 | unphased_unmod, unphased_mod=mod_call[0:2]
331 | phase1_unmod, phase1_mod=mod_call[2:4]
332 | phase2_unmod, phase2_mod=mod_call[4:6]
333 |
334 | mod=unphased_mod+phase1_mod+phase2_mod
335 | unmod=unphased_unmod+phase1_unmod+phase2_unmod
336 |
337 | total_stats=[mod+unmod, mod, unmod,mod/(mod+unmod) if mod+unmod>0 else 0]
338 | phase1_stats=[phase1_mod+phase1_unmod, phase1_mod, phase1_unmod, phase1_mod/(phase1_mod+phase1_unmod) if phase1_mod+phase1_unmod>0 else 0]
339 | phase2_stats=[phase2_mod+phase2_unmod, phase2_mod, phase2_unmod, phase2_mod/(phase2_mod+phase2_unmod) if phase2_mod+phase2_unmod>0 else 0]
340 |
341 | mod_str='{}\t{}\t{}\t{}\t'.format(chrom, pos, pos+1, strand)+'{}\t{}\t{}\t{:.4f}\t'.format(*total_stats) + '{}\t{}\t{}\t{:.4f}\t'.format(*phase1_stats) + '{}\t{}\t{}\t{:.4f}\n'.format(*phase2_stats)
342 |
343 | return total_stats[0], mod_str
344 |
345 | def get_cpg_output(params, input_list):
346 | qscore_cutoff=params['qscore_cutoff']
347 | length_cutoff=params['length_cutoff']
348 |
349 | mod_threshold=params['mod_t']
350 | unmod_threshold=params['unmod_t']
351 |
352 | cpg_ref_only=not params['include_non_cpg_ref']
353 |
354 |
355 |
356 | total_files=len(input_list)
357 | print('%s: Reading %d files.' %(str(datetime.datetime.now()), total_files), flush=True)
358 | pbar = tqdm(total=total_files)
359 |
360 | per_site_pred={}
361 |
362 | for read_pred_file in input_list:
363 | with open(read_pred_file,'r') as read_file:
364 | read_file.readline()
365 | for line in read_file:
366 | read, chrom, pos, pos_after, read_pos, strand, score, mean_qscore, sequence_length, phase, is_ref_cpg = line.rstrip('\n').split('\t')
367 |
368 |                 if pos=='NA' or float(mean_qscore)<qscore_cutoff or int(sequence_length)<length_cutoff:
369 |                     continue
370 | 
371 |                 score=float(score)
372 | 
373 |                 if score<mod_threshold and score>unmod_threshold:
374 | continue
375 | else:
376 | mod=score>=mod_threshold
377 |
378 | pos=int(pos)
379 | phase=int(phase)
380 | is_forward=1 if strand=='+' else 0
381 |
382 | idx=4*phase+2*is_forward
383 |
384 |                 is_ref_cpg = is_ref_cpg=='True'
385 | zero_based_fwd_pos=pos if strand=='+' else pos-1
386 |
387 | if (chrom, zero_based_fwd_pos) not in per_site_pred:
388 | per_site_pred[(chrom, zero_based_fwd_pos)]=[0]*12+[is_ref_cpg]
389 |
390 | per_site_pred[(chrom, zero_based_fwd_pos)][idx+mod]+=1
391 |
392 | pbar.update(1)
393 | pbar.close()
394 |
395 | print('%s: Writing Per Site Methylation Detection.' %str(datetime.datetime.now()), flush=True)
396 |
397 | per_site_fields=['#chromosome', 'position_before', 'position','strand', 'ref_cpg',
398 | 'coverage','mod_coverage', 'unmod_coverage','mod_fraction',
399 | 'coverage_phase1','mod_coverage_phase1', 'unmod_coverage_phase1','mod_fraction_phase1',
400 | 'coverage_phase2','mod_coverage_phase2', 'unmod_coverage_phase2','mod_fraction_phase2']
401 | per_site_header='\t'.join(per_site_fields)+'\n'
402 | per_site_fields.remove('strand')
403 | agg_per_site_header='\t'.join(per_site_fields)+'\n'
404 |
405 | per_site_file_path=os.path.join(params['output'],'%s.per_site' %params['prefix'])
406 | agg_per_site_file_path=os.path.join(params['output'],'%s.per_site.aggregated' %params['prefix'])
407 |
408 | with open(per_site_file_path, 'w') as per_site_file, open(agg_per_site_file_path,'w') as agg_per_site_file:
409 | per_site_file.write(per_site_header)
410 | agg_per_site_file.write(agg_per_site_header)
411 |
412 | for x in sorted(per_site_pred.keys()):
413 | chrom, pos=x
414 | cpg=per_site_pred[x]
415 | is_ref_cpg=cpg[12]
416 |
417 | if cpg_ref_only and is_ref_cpg==False:
418 | continue
420 |
421 | agg_stats, fwd_stats, rev_stats=get_stats_string_cpg(chrom, pos, is_ref_cpg, cpg)
422 | if agg_stats[0]>0:
423 | agg_per_site_file.write(agg_stats[1])
424 |
425 | if fwd_stats[0]>0:
426 | per_site_file.write(fwd_stats[1])
427 |
428 | if rev_stats[0]>0:
429 | per_site_file.write(rev_stats[1])
430 |
431 | print('%s: Finished Writing Per Site Methylation Output.' %str(datetime.datetime.now()), flush=True)
432 | print('%s: Per Site Prediction file: %s' %(str(datetime.datetime.now()), per_site_file_path), flush=True)
433 | print('%s: Aggregated Per Site Prediction file: %s' %(str(datetime.datetime.now()), agg_per_site_file_path), flush=True)
434 |
435 | def get_output(params, input_list):
436 | qscore_cutoff=params['qscore_cutoff']
437 | length_cutoff=params['length_cutoff']
438 |
439 | mod_threshold=params['mod_t']
440 | unmod_threshold=params['unmod_t']
441 |
442 | cpg_ref_only=not params['include_non_cpg_ref']
443 |
444 |
445 |
446 | total_files=len(input_list)
447 | print('%s: Reading %d files.' %(str(datetime.datetime.now()), total_files), flush=True)
448 | pbar = tqdm(total=total_files)
449 |
450 | per_site_pred={}
451 |
452 | for read_pred_file in input_list:
453 | with open(read_pred_file,'r') as read_file:
454 | read_file.readline()
455 | for line in read_file:
456 | read, chrom, pos, pos_after, read_pos, strand, score, mean_qscore, sequence_length, phase = line.rstrip('\n').split('\t')
457 |
458 |                 if pos=='NA' or float(mean_qscore)<qscore_cutoff or int(sequence_length)<length_cutoff:
459 |                     continue
460 | 
461 |                 score=float(score)
462 | 
463 |                 if score<mod_threshold and score>unmod_threshold:
464 | continue
465 | else:
466 | mod=score>=mod_threshold
467 |
468 | pos=int(pos)
469 | phase=int(phase)
470 |
471 | if (chrom, pos,strand) not in per_site_pred:
472 | per_site_pred[(chrom, pos,strand)]=[0]*6
473 |
474 | per_site_pred[(chrom, pos,strand)][2*phase+mod]+=1
475 |
476 | pbar.update(1)
477 | pbar.close()
478 |
479 | print('%s: Writing Per Site Methylation Detection.' %str(datetime.datetime.now()), flush=True)
480 |
481 | per_site_fields=['#chromosome', 'position_before', 'position','strand',
482 | 'coverage','mod_coverage', 'unmod_coverage','mod_fraction',
483 | 'coverage_phase1','mod_coverage_phase1', 'unmod_coverage_phase1','mod_fraction_phase1',
484 | 'coverage_phase2','mod_coverage_phase2', 'unmod_coverage_phase2','mod_fraction_phase2']
485 | per_site_header='\t'.join(per_site_fields)+'\n'
486 |
487 | per_site_file_path=os.path.join(params['output'],'%s.per_site' %params['prefix'])
488 |
489 | with open(per_site_file_path, 'w') as per_site_file:
490 | per_site_file.write(per_site_header)
491 |
492 | for x in sorted(per_site_pred.keys()):
493 | chrom, pos, strand=x
494 | mod_call=per_site_pred[x]
495 |
496 | stats=get_stats_string(chrom, pos, strand, mod_call)
497 | if stats[0]>0:
498 | per_site_file.write(stats[1])
499 |
500 | print('%s: Finished Writing Per Site Methylation Output.' %str(datetime.datetime.now()), flush=True)
501 | print('%s: Per Site Prediction file: %s' %(str(datetime.datetime.now()), per_site_file_path), flush=True)
502 |
503 | def get_per_site(params, input_list):
504 | print('%s: Starting Per Site Methylation Detection.' %str(datetime.datetime.now()), flush=True)
505 |
506 | if params['cpg_output']:
507 | get_cpg_output(params, input_list)
508 | else:
509 | get_output(params, input_list)
510 |
511 |
--------------------------------------------------------------------------------
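As a usage sketch, `motif_check` turns an IUPAC motif plus indices into a character-class pattern that `get_ref_info` feeds to `re.finditer` as a lookahead. The motif `DRACH 2` is only an illustrative input, and the import assumes the repo root is on `PYTHONPATH` with `src` importable as a package:

```python
from src.utils import motif_check  # assumption: src is importable as a package

motif_seq, exp_motif_seq, motif_ind, valid = motif_check(['DRACH', '2'])
print(motif_seq)      # 'DRACH'
print(exp_motif_seq)  # '[GAT][GA]AC[ACT]', used as re.finditer(r'(?=...)')
print(motif_ind)      # [2], the A being interrogated
print(valid)          # True
```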
/train/generate_features.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict, ChainMap
2 | import time, itertools, h5py, pysam
3 | import datetime, os, shutil, argparse, sys, re, array
5 | from itertools import repeat
6 | import multiprocessing as mp
7 | import numpy as np
8 | from pathlib import Path
9 | from ont_fast5_api.fast5_interface import get_fast5_file
10 | from numba import jit
11 | import queue, gzip
12 | import pod5 as p5
13 | import utils
14 |
15 | base_to_num_map={'A':0, 'C':1, 'G':2, 'T':3, 'U':3,'N':4}
16 |
17 | num_to_base_map={0:'A', 1:'C', 2:'G', 3:'T', 4:'N'}
18 |
19 | comp_base_map={'A':'T','T':'A','C':'G','G':'C','[':']', ']':'['}
20 |
21 | def revcomp(s):
22 | return ''.join(comp_base_map[x] for x in s[::-1])
23 |
24 | def motif_check(motif):
25 | nt_dict={'R': 'GA',
26 | 'Y': 'CT',
27 | 'K': 'GT',
28 | 'M': 'AC',
29 | 'S': 'GC',
30 | 'W': 'AT',
31 | 'B': 'GTC',
32 | 'D': 'GAT',
33 | 'H': 'ACT',
34 | 'V': 'GCA',
35 | 'N': 'AGCT'}
36 |
37 | valid_alphabet=set(nt_dict.keys()).union({'A', 'C', 'G', 'T'})
38 |
39 |     motif_seq, exp_motif_seq, final_motif_ind, valid = None, None, None, False
40 |
41 | if len(motif)<2:
42 | print('--motif not specified correctly. You need to specify a motif and at least one index',flush=True)
43 | return motif_seq, exp_motif_seq, final_motif_ind, valid
44 |
45 | elif len(set(motif[0])-valid_alphabet)>0:
46 | print('--motif not specified correctly. Motif should only consist of the following extended nucleotide letters: {}'.format(','.join(valid_alphabet)),flush=True)
47 | return motif_seq, exp_motif_seq, final_motif_ind, valid
48 |
49 |     elif not all(a.isnumeric() for a in motif[1:]):
50 |         print('--motif not specified correctly. Motif indices should be integers separated by whitespace and should come after the motif sequence.',flush=True)
51 | return motif_seq, exp_motif_seq, final_motif_ind, valid
52 |
53 | else:
54 | motif_seq=motif[0]
55 | motif_ind=[int(x) for x in motif[1:]]
56 |
57 | if len(set(motif_seq[x] for x in motif_ind))!=1 or len(set(motif_seq[x] for x in motif_ind)-set('ACGT'))>0:
58 |             print('Base of interest should be the same for all indices and must be one of A, C, G or T.', flush=True)
59 | return motif_seq, exp_motif_seq, final_motif_ind, valid
60 |
61 | else:
62 | exp_motif_seq=motif_seq
63 | for nt in nt_dict:
64 | if nt in exp_motif_seq:
65 | exp_motif_seq=exp_motif_seq.replace(nt, '[{}]'.format(nt_dict[nt]))
66 | return motif_seq, exp_motif_seq, motif_ind, True
67 |
68 | def get_candidates(read_seq, align_data, aligned_pairs, ref_pos_dict):
69 | is_mapped, is_forward, ref_name, reference_start, reference_end, read_length=align_data
70 |
71 | ref_motif_pos=ref_pos_dict[ref_name][0] if is_forward else ref_pos_dict[ref_name][1]
72 |
73 |     common_pos=ref_motif_pos[(ref_motif_pos>=reference_start)&(ref_motif_pos<reference_end)]
186 |     signal[signal>5]=5
187 | signal[signal<-5]=-5
188 |
189 | move_len=len(move_table)
190 | move_index=np.where(move_table)[0]
191 | rlen=len(move_index)
192 |
193 | data=np.zeros((rlen,9))
194 |
195 | for i in range(len(move_index)-1):
196 | prev=move_index[i]*stride+start
197 | sig_end=move_index[i+1]*stride+start
198 |
199 | sig_len=sig_end-prev
200 | data[i, 8]=np.log10(sig_len)
201 | data[i, 4]=np.median(signal[prev:sig_end])
202 | data[i, 5]=np.median(np.abs(signal[prev:sig_end]-data[i, 4]))
203 | data[i, 6]=np.mean(signal[prev:sig_end])
204 | data[i, 7]=np.std(signal[prev:sig_end])
205 |
206 | for j in range(4):
207 | tmp_cnt=0
208 | for t in range(j*sig_len//4,min(sig_len, (j+1)*sig_len//4)):
209 | data[i, j]+=signal[t+prev]
210 | tmp_cnt+=1
211 | data[i, j]=data[i, j]/tmp_cnt
212 |
213 | return data
214 |
215 | def get_pos(path):
216 | labelled_pos_list={}
217 | strand_map={'+':0, '-':1}
218 |
219 | with open(path) as file:
220 | for line in file:
221 | line=line.rstrip('\n').split('\t')
222 | if line[0] not in labelled_pos_list:
223 | labelled_pos_list[line[0]]={0:{}, 1:{}}
224 |
225 | labelled_pos_list[line[0]][strand_map[line[2]]][int(line[1])]=float(line[3])
226 |
227 | return labelled_pos_list
228 |
229 | def write_to_npz(output_file_path, mat, base_qual, base_seq, ref_seq, label, ref_coordinates, read_name, ref_name, window, norm_type):
230 | np.savez(output_file_path, mat=mat, base_qual=base_qual, base_seq=base_seq, ref_seq=ref_seq, label=label, ref_coordinates=ref_coordinates, read_name=read_name, ref_name=ref_name, window=window, norm_type=norm_type, strides_per_base=1, model_depth=mat.shape[2], full_signal=False)
231 |
232 | def get_output(params, output_Q, process_event):
233 | output=params['output']
234 |
235 | reads_per_chunk=params['reads_per_chunk']
236 |
237 | chunk=1
238 | read_count=0
239 |
240 | output_file_path=os.path.join(output,'%s.features.%d.npz' %(params['prefix'], chunk))
241 |
242 | mat, base_qual, base_seq, ref_seq, label=[], [], [], [], []
243 | ref_coordinates, read_name, ref_name= [], [], []
244 |
245 | while True:
246 | if process_event.is_set() and output_Q.empty():
247 | break
248 | else:
249 | try:
250 | res = output_Q.get(block=False)
251 | #per_site_features, per_site_base_qual, per_site_base_seq, per_site_ref_seq, per_site_ref_coordinates, per_site_label, read_name_array, ref_name_array
252 |
253 | mat.append(res[0])
254 | base_qual.append(res[1])
255 | base_seq.append(res[2])
256 | ref_seq.append(res[3])
257 | ref_coordinates.append(res[4])
258 | label.append(res[5])
259 | read_name.append(res[6])
260 | ref_name.append(res[7])
261 |
262 | read_count+=1
263 |
264 | if read_count%reads_per_chunk==0 and len(mat)>0:
265 | mat=np.vstack(mat)
266 | base_qual=np.vstack(base_qual)
267 | base_seq=np.vstack(base_seq).astype(np.int8)
268 | ref_seq=np.vstack(ref_seq).astype(np.int8)
269 | label=np.hstack(label).astype(np.float16)
270 | ref_coordinates=np.hstack(ref_coordinates)
271 | read_name=np.hstack(read_name)
272 | ref_name=np.hstack(ref_name)
273 |
274 | idx=np.random.permutation(np.arange(len(label)))
275 | mat=mat[idx]
276 | base_qual=base_qual[idx]
277 | base_seq=base_seq[idx]
278 | ref_seq=ref_seq[idx]
279 | label=label[idx]
280 | ref_coordinates=ref_coordinates[idx]
281 | read_name=read_name[idx]
282 | ref_name=ref_name[idx]
283 |
284 | print('%s: Number of reads processed = %d.' %(str(datetime.datetime.now()), read_count), flush=True)
285 |
286 |
287 | write_to_npz(output_file_path, mat, base_qual, base_seq, ref_seq, label, ref_coordinates, read_name, ref_name, window=params['window'], norm_type=params['norm_type'])
288 |
289 | chunk+=1
290 | output_file_path=os.path.join(output,'%s.features.%d.npz' %(params['prefix'], chunk))
291 | mat, base_qual, base_seq, ref_seq, label=[], [], [], [], []
292 | ref_coordinates, read_name, ref_name= [], [], []
293 |
294 | except queue.Empty:
295 | pass
296 |
297 | if read_count>0 and len(mat)>0:
298 | mat=np.vstack(mat)
299 | base_qual=np.vstack(base_qual)
300 | base_seq=np.vstack(base_seq).astype(np.int8)
301 | ref_seq=np.vstack(ref_seq).astype(np.int8)
302 | label=np.hstack(label).astype(np.float16)
303 | ref_coordinates=np.hstack(ref_coordinates)
304 | read_name=np.hstack(read_name)
305 | ref_name=np.hstack(ref_name)
306 |
307 | idx=np.random.permutation(np.arange(len(label)))
308 | mat=mat[idx]
309 | base_qual=base_qual[idx]
310 | base_seq=base_seq[idx]
311 | ref_seq=ref_seq[idx]
312 | label=label[idx]
313 | ref_coordinates=ref_coordinates[idx]
314 | read_name=read_name[idx]
315 | ref_name=ref_name[idx]
316 |
317 | print('%s: Number of reads processed = %d.' %(str(datetime.datetime.now()), read_count), flush=True)
318 |
319 |
320 | write_to_npz(output_file_path, mat, base_qual, base_seq, ref_seq, label, ref_coordinates, read_name, ref_name, params['window'], params['norm_type'])
321 |
322 | return
323 |
324 | def process(params, ref_pos_dict, signal_Q, output_Q, input_event, ref_seq_dict, labelled_pos_list):
325 | base_map={'A':0, 'C':1, 'G':2, 'T':3, 'U':3}
326 |
327 | window=params['window']
328 | window_range=np.arange(-window,window+1)
329 | norm_type=params['norm_type']
330 |
331 | div_threshold=params['div_threshold']
332 | cigar_map={'M':0, '=':0, 'X':0, 'D':1, 'I':2, 'S':2,'H':2, 'N':1, 'P':4, 'B':4}
333 | cigar_pattern = r'\d+[A-Za-z]'
334 |
335 | seq_type=params['seq_type']
336 | ref_available=True if params['ref'] else False
337 |
338 | while True:
339 | if (signal_Q.empty() and input_event.is_set()):
340 | break
341 |
342 | try:
343 | data=signal_Q.get(block=False)
344 | signal, move, read_dict, align_data=data
345 |
346 | is_mapped, is_forward, ref_name, reference_start,reference_end, read_length=align_data
347 |
348 | fq=read_dict['seq']
349 | qual=read_dict['qual']
350 | sequence_length=len(fq)
351 | reverse= not is_forward
352 | fq=revcomp(fq) if reverse else fq
353 | qual=qual[::-1] if reverse else qual
354 |
355 |             if is_mapped:
356 | cigar_tuples = np.array([(int(x[:-1]), cigar_map[x[-1]]) for x in re.findall(cigar_pattern, read_dict['cigar'])])
357 | ref_start=int(read_dict['ref_pos'])-1
358 | aligned_pairs=get_aligned_pairs(cigar_tuples, ref_start)
359 | else:
360 | continue
361 |
362 | init_pos_list_candidates, read_to_ref_pairs=get_candidates(fq, align_data, aligned_pairs, ref_pos_dict)
363 | init_pos_list_candidates=init_pos_list_candidates[(init_pos_list_candidates[:,0]>window)\
364 |                                 &(init_pos_list_candidates[:,0]<sequence_length-window)] if len(init_pos_list_candidates)>0 else init_pos_list_candidates
365 |
366 | if len(init_pos_list_candidates)==0:
367 |
368 | continue
369 |
370 | base_seq=np.array([base_map[x] for x in fq])
371 | ref_seq=ref_seq_dict[ref_name][:,1][read_to_ref_pairs[:, 1]][::-1] if reverse else ref_seq_dict[ref_name][:,0][read_to_ref_pairs[:, 1]]
372 |
373 |
374 | label_filter_idx=np.array([np.mean(ref_seq[candidate[0]-window: candidate[0]+window+1]!=\
375 | base_seq[candidate[0]-window: candidate[0]+window+1])<=div_threshold \
376 | for candidate in init_pos_list_candidates])
377 | pos_list_candidates=init_pos_list_candidates[label_filter_idx]
378 |
379 |
380 | if len(pos_list_candidates)==0:
381 | continue
382 |
383 | if not move[0]:
384 | try:
385 | tags={x.split(':')[0]:x for x in read_dict.pop('tags')}
386 | start=int(tags['ts'].split(':')[-1])
387 | mv=tags['mv'].split(',')
388 |
389 | stride=int(mv[1])
390 | move_table=np.fromiter(mv[2:], dtype=np.int8)
391 | move=(stride, start, move_table)
392 | read_dict['tags']=[x for x in tags.values() if x[:2] not in ['mv', 'ts', 'ML', 'MM']]
393 | except KeyError:
394 |                     print('Read:%s no move table or stride or signal start found' %read_dict['name'], flush=True)
395 |                     # move table info missing for this read; skip it
396 | continue
397 |
398 | base_qual=10**((33-np.array([ord(x) for x in qual]))/10)
399 | mean_qscore=-10*np.log10(np.mean(base_qual))
400 | base_qual=(1-base_qual)
401 |
402 | mat=get_events(signal, move, norm_type)
403 |
404 | if seq_type=='rna':
405 | mat=np.flip(mat,axis=0)
406 |
407 | per_site_features=np.array([mat[candidate[0]-window: candidate[0]+window+1] for candidate in pos_list_candidates])
408 | per_site_base_qual=np.array([base_qual[candidate[0]-window: candidate[0]+window+1] for candidate in pos_list_candidates])
409 | per_site_base_seq=np.array([base_seq[candidate[0]-window: candidate[0]+window+1] for candidate in pos_list_candidates])
410 | per_site_ref_seq=np.array([ref_seq[candidate[0]-window: candidate[0]+window+1] for candidate in pos_list_candidates])
411 | per_site_ref_coordinates=pos_list_candidates[:,1]
412 | per_site_label=np.array([labelled_pos_list[ref_name][1-is_forward][coord] for coord in per_site_ref_coordinates])
413 | read_name_array=np.array([read_dict['name'] for candidate in pos_list_candidates])
414 | ref_name_array=np.array([ref_name for candidate in pos_list_candidates])
415 |
416 | read_chunks=[per_site_features, per_site_base_qual, per_site_base_seq, per_site_ref_seq, per_site_ref_coordinates, per_site_label, read_name_array, ref_name_array]
417 |
418 | output_Q.put(read_chunks)
419 |
420 |
421 | except queue.Empty:
422 | pass
423 |
424 | return
425 |
426 | def get_input(params, signal_Q, output_Q, input_event):
427 | chrom_list=params['chrom']
428 |
429 | length_cutoff=params['length_cutoff']
430 |
431 | bam=params['bam']
432 | bam_file=pysam.AlignmentFile(bam,'rb',check_sq=False)
433 |
434 | print('%s: Building BAM index.' %str(datetime.datetime.now()), flush=True)
435 | bam_index=pysam.IndexedReads(bam_file)
436 | bam_index.build()
437 | print('%s: Finished building BAM index.' %str(datetime.datetime.now()), flush=True)
438 |
439 | input_=params['input']
440 | signal_files= [input_] if os.path.isfile(input_) else Path(input_).rglob("*.%s" %params['file_type'])
441 |
442 | counter=0
443 | if params['file_type']=='fast5':
444 | guppy_group=params['guppy_group']
445 | for filename in signal_files:
446 | with get_fast5_file(filename, mode="r") as f5:
447 | for read in f5.get_reads():
448 | counter+=1
449 | if counter%10000==0:
450 | print('%s: Number of reads read = %d.' %(str(datetime.datetime.now()), counter), flush=True)
451 | if signal_Q.qsize()>10000:
452 | time.sleep(10)
453 |
454 | read_name=read.read_id
455 | try:
456 | read_iter=bam_index.find(read_name)
457 | for bam_read in read_iter:
458 |
459 | if bam_read.flag & 0x900==0 and bam_read.reference_name in chrom_list and bam_read.query_length>=length_cutoff:
460 | read_dict=bam_read.to_dict()
461 | signal=read.get_raw_data()
462 |
463 | if params['fast5_move']:
464 | segment=read.get_analysis_attributes(guppy_group)['segmentation']
465 | start=read.get_analysis_attributes('%s/Summary/segmentation' %segment)['first_sample_template']
466 | stride=read.get_summary_data(guppy_group)['basecall_1d_template']['block_stride']
467 | move_table=read.get_analysis_dataset('%s/BaseCalled_template' %guppy_group, 'Move')
468 | move=(stride, start, move_table)
469 | else:
470 | move=(None,None,None)
471 |
472 | align_data=(bam_read.is_mapped if params['ref'] else False,
473 | bam_read.is_forward, bam_read.reference_name, bam_read.reference_start, bam_read.reference_end, bam_read.query_length)
474 | data=(signal, move, read_dict, align_data)
475 | signal_Q.put(data)
476 |
477 | except KeyError as error:
478 | continue
479 |
480 | else:
481 | move=(None,None,None)
482 | for filename in signal_files:
483 | with p5.Reader(filename) as reader:
484 | for read in reader.reads():
485 | counter+=1
486 | if counter%10000==0:
487 | print('%s: Number of reads processed = %d.' %(str(datetime.datetime.now()), counter), flush=True)
488 |
489 | if signal_Q.qsize()>10000:
490 | time.sleep(10)
491 |
492 | read_name=str(read.read_id)
493 | try:
494 | read_iter=bam_index.find(read_name)
495 | for bam_read in read_iter:
497 | if bam_read.flag & 0x900==0 and bam_read.reference_name in chrom_list and bam_read.query_length>=length_cutoff:
498 | read_dict=bam_read.to_dict()
499 | signal=read.signal
500 | align_data=(bam_read.is_mapped if params['ref'] else False,
501 | bam_read.is_forward, bam_read.reference_name, bam_read.reference_start, bam_read.reference_end, bam_read.query_length)
502 | data=(signal, move, read_dict, align_data)
503 | signal_Q.put(data)
504 |
505 | except KeyError:
506 | continue
507 |
508 | input_event.set()
509 | return
510 |
511 | def call_manager(params):
512 | bam=params['bam']
513 | bam_file=pysam.AlignmentFile(bam,'rb',check_sq=False)
514 | header_dict=bam_file.header.to_dict()
515 |
516 | print('%s: Getting motif positions from the reference.' %str(datetime.datetime.now()), flush=True)
517 |
518 | ref_seq_dict={}
519 | ref_pos_dict={}
520 |
521 | labelled_pos_list={}
522 |
523 | if params['pos_list']:
524 | labelled_pos_list=get_pos(params['pos_list'])
525 | params['chrom']=[x for x in params['chrom'] if x in labelled_pos_list.keys()]
526 |
527 | motif_label=params['motif_label']
528 | _=get_ref_to_num('ACGT')
529 | ref_seq_dict={}
530 |
531 | with mp.Pool(processes=params['threads']) as pool:
532 | res=pool.map(get_ref_info, zip(repeat(params), params['chrom']))
533 | for r in res:
534 | chrom, seq_array, fwd_pos_array, rev_pos_array=r
535 | ref_seq_dict[chrom]=seq_array
536 |
537 | if params['pos_list']:
538 | ref_pos_dict[chrom]=(np.array(sorted(list(labelled_pos_list[chrom][0].keys()))).astype(int), np.array(sorted(list(labelled_pos_list[chrom][1].keys()))).astype(int))
539 |
540 | else:
541 | ref_pos_dict[chrom]=(fwd_pos_array, rev_pos_array)
542 | labelled_pos_list[chrom]={0:{}, 1:{}}
543 | for strand in [0,1]:
544 | for pos in ref_pos_dict[chrom][strand]:
545 | labelled_pos_list[chrom][strand][pos]=float(motif_label)
546 |
547 |
548 |
549 | print('%s: Finished getting motif positions from the reference.' %str(datetime.datetime.now()), flush=True)
550 |
551 | pmanager = mp.Manager()
552 | signal_Q = pmanager.Queue()
553 | output_Q = pmanager.Queue()
554 | process_event=pmanager.Event()
555 | input_event=pmanager.Event()
556 |
557 | handlers = []
558 |
559 | input_process = mp.Process(target=get_input, args=(params, signal_Q, output_Q, input_event))
560 | input_process.start()
561 | handlers.append(input_process)
562 |
563 | for hid in range(max(1,params['threads']-1)):
564 | p = mp.Process(target=process, args=(params, ref_pos_dict, signal_Q, output_Q, input_event, ref_seq_dict, labelled_pos_list));
565 | p.start();
566 | handlers.append(p);
567 |
568 | output_process=mp.Process(target=get_output, args=(params, output_Q, process_event));
569 | output_process.start();
570 |
571 | for job in handlers:
572 | job.join()
573 |
574 | process_event.set()
575 | output_process.join()
576 |
577 | return
578 |
579 | if __name__ == '__main__':
580 |
581 | t=time.time()
582 |
583 | print('%s: Starting feature generation.' %str(datetime.datetime.now()))
584 |
585 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
586 |
587 | parser.add_argument("--bam", help='Path to bam file', type=str, required=True)
588 |     parser.add_argument("--window", help='Number of bases before and after the base of interest to include in the model. The total number of bases included in the model will be 2xwindow+1.', type=int, default=10)
589 | parser.add_argument("--prefix", help='Prefix for the output files',type=str, default='output')
590 |     parser.add_argument("--input", help='Path to folder containing POD5 or FAST5 files. Files will be searched recursively.', type=str, required=True)
591 |
592 | parser.add_argument("--output", help='Path to folder where features will be stored', type=str, required=True)
593 |
594 | parser.add_argument("--threads", help='Number of processors to use',type=int, default=1)
595 |
596 | parser.add_argument("--div_threshold", help='Divergence Threshold.',type=float, default=0.25)
597 |
598 | parser.add_argument("--reads_per_chunk", help='reads_per_chunk',type=int, default=100000)
599 |
600 | parser.add_argument("--ref", help='Path to reference FASTA file to anchor methylation calls to reference loci. If no reference is provided, only the motif loci on reads will be used.', type=str)
601 |
602 |     parser.add_argument("--pos_list", help='Text file containing a list of positions to generate features for. Use either --pos_list or --motif to specify how to choose loci for feature generation, but not both. The file should be whitespace separated with the following information on each line: chrom pos strand label. The position is a 0-based reference coordinate, strand is + for forward and - for reverse strand, and label is 1 for modified or 0 for unmodified.', type=str)
603 |
604 | parser.add_argument("--file_type", help='Specify whether the signal is in FAST5 or POD5 file format. If POD5 file is used, then move table must be in BAM file.',choices=['fast5','pod5'], type=str,required=True)
605 |
606 | parser.add_argument("--seq_type", help='Specify DNA or direct RNA sequencing.',choices=['dna','rna'], type=str,required=True)
607 |
608 | parser.add_argument("--guppy_group", help='Name of the guppy basecall group',type=str, default='Basecall_1D_000')
609 |     parser.add_argument("--chrom", nargs='*', help='A whitespace separated list of contigs, e.g. chr3 chr6 chr22. If no list is provided, all contigs in the BAM header are used.')
610 | parser.add_argument("--length_cutoff", help='Minimum cutoff for read length',type=int, default=0)
611 |     parser.add_argument("--fast5_move", help='Use move table from FAST5 file instead of BAM file. If this flag is set, specify a basecall group for the FAST5 file using the --guppy_group parameter and ensure that the FAST5 files contain move tables.', default=False, action='store_true')
612 |
613 | parser.add_argument("--motif", help='Motif for generating features followed by zero-based indices of nucleotides within the motif to generate features for. The motif and each index listed should be separated by whitespace, e.g. "--motif CGCG 0 2". Use either --pos_list or --motif to specify how to choose loci for feature generation, but not both. Features will be generated for all loci of the read that map to a reference sequence that matches the motif. Multiple indices can be specified but they should refer to the same nucleotide letter. If you use --motif, it is assumed that all loci have the same modification label and you need to specify the label using --motif_label.', nargs='*')
614 |
615 | parser.add_argument("--motif_label", help='Modification label for the motif. 0 is for unmodified and 1 is for modified.',type=int, choices=[0,1])
616 |
617 | parser.add_argument("--norm_type", help='How to normalize read signal.',type=str, choices=['mad', 'standard'], default='mad')
618 |
619 | args = parser.parse_args()
620 |
621 | if not args.output:
622 | args.output=os.getcwd()
623 |
624 | os.makedirs(args.output, exist_ok=True)
625 |
626 |
627 | if args.chrom:
628 | chrom_list=args.chrom
629 | else:
630 |         chrom_list=pysam.AlignmentFile(args.bam).references
631 |
632 |
633 | if args.motif and len(args.motif)>0:
634 | if args.pos_list is not None:
635 | print('Use either --motif or --pos_list but not both', flush=True)
636 | sys.exit(3)
637 |
638 | if args.motif_label is None:
639 | print('--motif_label should be specified with --motif option', flush=True)
640 | sys.exit(3)
641 |
642 | motif_seq, exp_motif_seq, motif_ind, valid_motif=motif_check(args.motif)
643 | if not valid_motif:
644 | sys.exit(3)
645 |
646 | else:
647 | motif_seq=None
648 | motif_ind=None
649 | exp_motif_seq=None
650 |
651 | if args.pos_list is None:
652 | print('Use either --motif or --pos_list', flush=True)
653 | sys.exit(3)
654 |
655 | params=dict(bam=args.bam,
656 | seq_type=args.seq_type,
657 | window=args.window,
658 | pos_list=args.pos_list,
659 | ref=args.ref,
660 | input=args.input,
661 | norm_type=args.norm_type,
662 | motif_seq=motif_seq,
663 | motif_ind=motif_ind,
664 | exp_motif_seq=exp_motif_seq,
665 | motif_label=args.motif_label,
666 | file_type=args.file_type,
667 | guppy_group=args.guppy_group,
668 | fast5_move=args.fast5_move,
669 | chrom=chrom_list, threads=args.threads,
670 | length_cutoff=args.length_cutoff,
671 | output=args.output, prefix=args.prefix,
672 | div_threshold=args.div_threshold, reads_per_chunk=args.reads_per_chunk)
673 |
674 | print(args)
675 |
676 |
677 | with open(os.path.join(args.output,'args'),'w') as file:
678 | file.write('Command: python %s\n\n\n' %(' '.join(sys.argv)))
679 | file.write('------Parameters Used For Running DeepMod2------\n')
680 | for k in vars(args):
681 | file.write('{}: {}\n'.format(k,vars(args)[k]) )
682 | call_manager(params)
683 | print('\n%s: Time elapsed=%.4fs' %(str(datetime.datetime.now()),time.time()-t))
684 |
--------------------------------------------------------------------------------
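Putting the argument parser above together, a typical feature-generation run might look like the following; every flag used here is defined in the parser above, and all paths are placeholders:

```
python generate_features.py --bam aligned.bam --input pod5_dir/ \
    --file_type pod5 --seq_type dna --ref ref.fa \
    --motif CG 0 --motif_label 1 \
    --output features/ --prefix hg002 --threads 8
```

Here `--motif CG 0` with `--motif_label 1` labels every reference CpG cytosine as modified, the usual encoding for a fully methylated control sample.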
/train/train_models.py:
--------------------------------------------------------------------------------
1 | import math, time, argparse, re, os, sys
2 | import functools, itertools, random
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn import TransformerEncoder, TransformerEncoderLayer
7 | from torch import optim, Tensor
8 | import numpy as np
9 | import multiprocessing as mp
11 | import queue
12 | import pickle
13 | from utils import *
14 |
15 | class OneHotEncode(nn.Module):
16 | def __init__(self, num_classes: int):
17 | super().__init__()
18 | self.num_classes=num_classes
19 | def forward(self, x: Tensor) -> Tensor:
20 | return F.one_hot(x, self.num_classes)
21 |
22 | class RefandReadEmbed(nn.Module):
23 | def __init__(self, embedding_dim, embedding_type):
24 | super().__init__()
25 |
26 | self.embedding_depth=0
27 |
28 | if embedding_type=='one_hot':
29 | self.read_emb=OneHotEncode(4)
30 | self.embedding_depth+=4
31 |
32 | self.ref_emb=OneHotEncode(5)
33 | self.embedding_depth+=5
34 |
35 | elif embedding_type=='learnable':
36 | self.read_emb=nn.Embedding(4, embedding_dim)
37 | self.embedding_depth+=embedding_dim
38 |
39 | self.ref_emb=nn.Embedding(5, embedding_dim)
40 | self.embedding_depth+=embedding_dim
41 |
42 | def forward(self, batch_base_seq, batch_ref_seq):
43 | batch_base_seq_emb=self.read_emb(batch_base_seq)
44 | batch_ref_seq_emb=self.ref_emb(batch_ref_seq)
45 |
46 | return torch.cat((batch_base_seq_emb, batch_ref_seq_emb), 2)
47 |
48 | class ReadEmbed(nn.Module):
49 | def __init__(self, embedding_dim, embedding_type):
50 | super().__init__()
51 |
52 | self.embedding_depth=0
53 |
54 | if embedding_type=='one_hot':
55 | self.read_emb=OneHotEncode(4)
56 | self.embedding_depth+=4
57 |
58 | elif embedding_type=='learnable':
59 | self.read_emb=nn.Embedding(4, embedding_dim)
60 | self.embedding_depth+=embedding_dim
61 |
62 |
63 | def forward(self, batch_base_seq, batch_ref_seq):
64 | batch_base_seq_emb=self.read_emb(batch_base_seq)
65 |
66 | return batch_base_seq_emb
67 |
68 | class SeqEmbed(nn.Module):
69 | def __init__(self, embedding_dim, embedding_type, include_ref):
70 | super().__init__()
71 |
72 | self.embedding_depth=0
73 |
74 | if include_ref:
75 | self.seq_emb=RefandReadEmbed(embedding_dim, embedding_type)
76 |
77 | else:
78 | self.seq_emb=ReadEmbed(embedding_dim, embedding_type)
79 |
80 | self.embedding_depth=self.seq_emb.embedding_depth
81 |
82 | def forward(self, batch_base_seq, batch_ref_seq):
83 | return self.seq_emb(batch_base_seq, batch_ref_seq)
84 |
85 | class PositionalEncoding(nn.Module):
86 | def __init__(self, pe_dim: int, max_len: int):
87 | super().__init__()
88 | position = torch.arange(max_len).unsqueeze(1)
89 | div_term = torch.exp(torch.arange(0, pe_dim, 2) * (-math.log(pe_dim) / (pe_dim)))
90 | pe = torch.zeros(1, max_len, pe_dim)
91 | pe[0,:, 0::2] = torch.sin(position * div_term)
92 | pe[0,:, 1::2] = torch.cos(position * div_term)
93 | self.register_buffer('pe', pe)
94 |
95 | def forward(self, x: Tensor) -> Tensor:
96 | x_pos=torch.Tensor.repeat(self.pe,(x.size(0),1,1))
97 | x = torch.cat((x, x_pos),2)
98 | return x
99 |
100 | class PositionalEmbedding(nn.Module):
101 | def __init__(self, pe_dim: int, max_len: int):
102 | super().__init__()
103 | pos=torch.arange(max_len)
104 | self.register_buffer('pos', pos)
105 | self.pe=nn.Embedding(max_len, pe_dim)
106 |
107 | def forward(self, x: Tensor) -> Tensor:
108 | x_pos=self.pe(self.pos)
109 | x_pos=torch.Tensor.repeat(x_pos,(x.size(0),1,1))
110 | x = torch.cat((x, x_pos),2)
111 | return x
112 |
113 | class PositionalParameter(nn.Module):
114 | def __init__(self, pe_dim: int, max_len: int):
115 | super().__init__()
116 |
117 | self.pe=torch.nn.Parameter(torch.randn(max_len, pe_dim))
118 |
119 | def forward(self, x: Tensor) -> Tensor:
120 | x_pos=torch.Tensor.repeat(self.pe,(x.size(0),1,1))
121 | x = torch.cat((x, x_pos),2)
122 | return x
123 |
124 | class ClassifierMiddle(nn.Module):
125 | def __init__(self, in_dim: int, num_fc: int, model_len: int):
126 | super().__init__()
127 | self.mid = model_len//2
128 | self.fc = nn.Linear(in_dim, num_fc)
129 | self.out = nn.Linear(num_fc,1)
130 |
131 | def forward(self, x):
132 | x = F.relu(self.fc(x[:,self.mid, :]))
133 | x=self.out(x)
134 | return x
135 |
136 | class ClassifierAll(nn.Module):
137 | def __init__(self, in_dim: int, num_fc: int):
138 | super().__init__()
139 | self.fc = nn.Linear(in_dim, num_fc)
140 | self.out = nn.Linear(num_fc,1)
141 |
142 | def forward(self, x):
143 | x = torch.flatten(x, 1)
144 | x = F.relu(self.fc(x))
145 | x=self.out(x)
146 | return x
147 |
148 | class BiLSTM(nn.Module):
149 | def __init__(self, model_dims, num_layers, dim_feedforward, num_fc, embedding_dim, embedding_type, include_ref, fc_type):
150 | super(BiLSTM, self).__init__()
151 |
152 | self.emb=SeqEmbed(embedding_dim, embedding_type, include_ref)
153 | self.model_len=model_dims[0]
154 | self.model_depth=model_dims[1]+self.emb.embedding_depth
155 |
156 | self.bilstm = nn.LSTM(input_size=self.model_depth, hidden_size=dim_feedforward, num_layers=num_layers, bidirectional=True, batch_first = True)
157 |
158 | if fc_type=='middle':
159 | self.classifier=ClassifierMiddle(in_dim=dim_feedforward*2, num_fc=num_fc, model_len=self.model_len)
160 |
161 | else:
162 | self.classifier=ClassifierAll(in_dim=self.model_len*dim_feedforward*2, num_fc=num_fc)
163 |
164 | def forward(self, batch_x, batch_base_seq, batch_ref_seq):
165 | seq_emb=self.emb(batch_base_seq, batch_ref_seq)
166 | x=torch.cat((batch_x, seq_emb), 2)
167 | x, _=self.bilstm(x)
168 | x = self.classifier(x)
169 |
170 | return x
171 |
172 | class TransformerModel(nn.Module):
173 | def __init__(self, model_dims, num_layers, dim_feedforward, num_fc, embedding_dim, embedding_type, include_ref, pe_dim, nhead, pe_type, fc_type):
174 | super(TransformerModel, self).__init__()
175 |
176 | self.emb=SeqEmbed(embedding_dim, embedding_type, include_ref)
177 | self.model_len=model_dims[0]
178 |
179 | if pe_type=='fixed':
180 | self.pe_block=PositionalEncoding(pe_dim=pe_dim, max_len=self.model_len)
181 |
182 | elif pe_type=='embedding':
183 | self.pe_block=PositionalEmbedding(pe_dim=pe_dim, max_len=self.model_len)
184 |
185 | elif pe_type=='parameter':
186 | self.pe_block=PositionalParameter(pe_dim=pe_dim, max_len=self.model_len)
187 |
188 | self.model_depth=model_dims[1]+self.emb.embedding_depth+pe_dim
189 | self.pad_length=math.ceil(self.model_depth/nhead)*nhead-self.model_depth
190 | pad=torch.zeros(1,self.model_len, self.pad_length)
191 | self.register_buffer('pad', pad)
192 | self.model_depth+=self.pad_length
193 |
194 | encoder_layer = nn.TransformerEncoderLayer(d_model=self.model_depth, nhead=nhead, dim_feedforward=dim_feedforward, dropout=0, batch_first = True)
195 | self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
196 |
197 | if fc_type=='middle':
198 | self.classifier=ClassifierMiddle(in_dim=self.model_depth, num_fc=num_fc, model_len=self.model_len)
199 |
200 | else:
201 | self.classifier=ClassifierAll(in_dim=self.model_len*self.model_depth, num_fc=num_fc)
202 |
203 | def forward(self, batch_x, batch_base_seq, batch_ref_seq):
204 | seq_emb=self.emb(batch_base_seq, batch_ref_seq)
205 | x=torch.cat((batch_x, seq_emb), 2)
206 | x=self.pe_block(x)
207 | x_pad=torch.Tensor.repeat(self.pad,(x.size(0),1,1))
208 | x = torch.cat((x, x_pad),2)
209 |
210 | x=self.transformer_encoder(x)
211 | x = self.classifier(x)
212 |
213 | return x
214 |
215 |
216 |
217 | def train(training_dataset, validation_dataset, validation_type, validation_fraction, model_config, epochs, prefix, retrain, batch_size, args_str, seed):
218 | print('Starting training.' , flush=True)
219 | torch.manual_seed(seed)
220 | model_type = model_config['model_type']
221 | model_save_path = model_config.pop('model_save_path')
222 |
223 | if torch.cuda.is_available():
224 | dev = "cuda:0"
225 | else:
226 | dev = "cpu"
227 |
228 | weight_counts=np.array([np.sum(np.eye(2)[np.load(f)['label'].astype(int)],axis=0) for f in itertools.chain.from_iterable(training_dataset)])
229 | weight_counts=np.sum(weight_counts,axis=0)
230 |
231 |
232 | if model_config['weights']=='equal':
233 | pos_weight=torch.Tensor(np.array(1.0))
234 |
235 | elif model_config['weights']=='auto':
236 | pos_weight=torch.Tensor(np.array(weight_counts[0]/weight_counts[1]))
237 |
238 | else:
239 | pos_weight=torch.Tensor(np.array(float(model_config['weights'])))
240 |
241 | print('Number of Modified Instances={}\nNumber of Un-Modified Instances={}\nPositive Label Weight={}\n'.format(weight_counts[1],weight_counts[0],pos_weight), flush=True)
242 |
243 | if model_type=='bilstm':
244 | net = BiLSTM(model_dims=model_config['model_dims'], num_layers=model_config['num_layers'], \
245 | dim_feedforward=model_config['dim_feedforward'], \
246 | num_fc=model_config['num_fc'], embedding_dim=model_config['embedding_dim'], \
247 | embedding_type=model_config['embedding_type'], include_ref=model_config['include_ref'], \
248 | fc_type=model_config['fc_type']);
249 |
250 | elif model_type=='transformer':
251 | net = TransformerModel(model_dims=model_config['model_dims'], num_layers=model_config['num_layers'], \
252 | dim_feedforward=model_config['dim_feedforward'], \
253 | num_fc=model_config['num_fc'], embedding_dim=model_config['embedding_dim'], \
254 | embedding_type=model_config['embedding_type'], include_ref=model_config['include_ref'],\
255 | pe_dim=model_config['pe_dim'], nhead=model_config['nhead'], \
256 | pe_type=model_config['pe_type'], fc_type=model_config['fc_type']);
257 |
258 | net.to(dev);
259 |
260 | optimizer = optim.Adam(net.parameters(), lr= model_config['lr'], weight_decay=model_config['l2_coef'])
261 |
262 | if retrain:
263 | checkpoint = torch.load(retrain)
264 | net.load_state_dict(checkpoint['model_state_dict'])
265 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
266 |
267 |
268 |     model_details=str(net)
269 | print(model_details, flush=True)
270 | num_params=sum(p.numel() for p in net.parameters())
271 | print('# Parameters=', num_params, flush=True)
272 |
273 | config_path=os.path.join(model_save_path, '%s.cfg' %prefix)
274 |
275 | with open(config_path, 'wb') as handle:
276 | pickle.dump(model_config, handle, protocol=pickle.HIGHEST_PROTOCOL)
277 |
278 | log_file_path=os.path.join(model_save_path, '%s.log' %prefix)
279 |
280 | train_w_wo_ref=model_config['train_w_wo_ref']
281 | include_ref=model_config['include_ref']
282 |
283 | if include_ref and train_w_wo_ref:
284 | list_of_evals=['Normal', 'Without_Ref']
285 | else:
286 | list_of_evals=['Normal']
287 |
288 | dummy_ref_seq=(4+torch.zeros(batch_size, model_config["model_dims"][0])).type(torch.LongTensor).to(dev)
289 |
290 | with open(log_file_path,'w') as log_file:
291 | log_file.write(args_str)
292 | log_file.write('\n# Parameters=%d\n' %num_params)
293 | log_file.write(model_details)
294 |
295 | for j in range(epochs):
296 | net.train()
297 |
298 | metrics_train={g:{'TP':0,'FP':0,'FN':0,'loss':0,'len':0,'true':0} for g in list_of_evals}
299 | metrics_test={g:{'TP':0,'FP':0,'FN':0,'loss':0,'len':0,'true':0} for g in list_of_evals}
300 |
301 | t=time.time()
302 |
303 | train_gen=generate_batches_mixed_can_mod(training_dataset, validation_type, validation_fraction, data_type="train", batch_size=batch_size)
304 |
305 | for batch in train_gen:
306 | batch_x, batch_base_seq, batch_ref_seq, batch_y =batch
307 | batch_x, batch_base_seq, batch_ref_seq, batch_y=batch_x.to(dev), batch_base_seq.to(dev), batch_ref_seq.to(dev), batch_y.to(dev)
308 |
309 | optimizer.zero_grad()
310 | score= net(batch_x, batch_base_seq, batch_ref_seq)
311 | loss = torch.nn.functional.binary_cross_entropy_with_logits(score, batch_y,pos_weight=pos_weight)
312 |
313 | loss.backward()
314 | optimizer.step()
315 |
316 | get_metrics(metrics_train,'Normal', batch_y, score, loss)
317 |
318 | if include_ref and train_w_wo_ref:
319 | dummy_batch_ref_seq=dummy_ref_seq[:batch_ref_seq.size(0)]
320 | optimizer.zero_grad()
321 | score= net(batch_x, batch_base_seq, dummy_batch_ref_seq)
322 | loss = torch.nn.functional.binary_cross_entropy_with_logits(score, batch_y,pos_weight=pos_weight)
323 |
324 | loss.backward()
325 | optimizer.step()
326 | get_metrics(metrics_train,'Without_Ref', batch_y, score, loss)
327 |
328 | with torch.no_grad():
329 | net.eval()
330 |
331 | if validation_type=='split':
332 | test_gen=generate_batches(list(itertools.chain.from_iterable(training_dataset)), validation_type, validation_fraction, data_type="test", batch_size=batch_size)
333 |
334 | else:
335 | test_gen=generate_batches(validation_dataset, validation_type, validation_fraction, data_type="test", batch_size=batch_size)
336 |
337 | for batch in test_gen:
338 | batch_x, batch_base_seq, batch_ref_seq, batch_y = batch
339 | batch_x, batch_base_seq, batch_ref_seq, batch_y=batch_x.to(dev), batch_base_seq.to(dev), batch_ref_seq.to(dev), batch_y.to(dev)
340 |
341 | score= net(batch_x, batch_base_seq, batch_ref_seq)
342 | loss = torch.nn.functional.binary_cross_entropy_with_logits(score, batch_y,pos_weight=pos_weight)
343 |
344 | get_metrics(metrics_test,'Normal', batch_y, score, loss)
345 |
346 | if include_ref and train_w_wo_ref:
347 | dummy_batch_ref_seq=dummy_ref_seq[:batch_ref_seq.size(0)]
348 | score= net(batch_x, batch_base_seq, dummy_batch_ref_seq)
349 | loss = torch.nn.functional.binary_cross_entropy_with_logits(score, batch_y,pos_weight=pos_weight)
350 |
351 | get_metrics(metrics_test,'Without_Ref', batch_y, score, loss)
352 |
353 | train_str, _ = get_stats(metrics_train, 'Training')
354 | test_str, total_test_acc = get_stats(metrics_test, 'Testing')
355 |
356 | epoch_log='\n\nEpoch %d: #Train=%d #Test=%d Time=%.4f\n%s\n\n%s'\
357 | %(j+1, sum(x['len'] for x in metrics_train.values()), sum(x['len'] for x in metrics_test.values()), time.time()-t,
358 | train_str, test_str)
359 | print(epoch_log, flush=True)
360 | log_file.write(epoch_log)
361 | log_file.flush()
362 | os.fsync(log_file.fileno())
363 |
364 | model_path=os.path.join(model_save_path, 'model.epoch%d.%.4f' %(j+1, total_test_acc))
365 | torch.save({
366 | 'epoch': j+1,
367 | 'model_state_dict': net.state_dict(),
368 | 'optimizer_state_dict': optimizer.state_dict()}, model_path)
369 |
370 | return net
371 |
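`train()` leaves behind three kinds of artifacts in `--model_save_path`: a pickled `<prefix>.cfg` with the model configuration, a `<prefix>.log` training log, and one checkpoint per epoch named `model.epoch<N>.<test_accuracy>` holding the epoch number plus model and optimizer state dicts. A minimal sketch of reading these back (paths below are placeholders; to resume training, pass the checkpoint path via `--retrain` instead):

```
import pickle
import torch

# Model hyperparameters written by train() before the first epoch
with open('trained_model/model.cfg', 'rb') as handle:
    model_config = pickle.load(handle)

# Any per-epoch checkpoint; keys match the torch.save() call above
checkpoint = torch.load('trained_model/model.epoch5.0.9500', map_location='cpu')
print(checkpoint['epoch'])
# net.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
```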
372 | if __name__=='__main__':
373 | start_time=time.time()
374 |
375 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
376 |
377 | parser.add_argument("--mixed_training_dataset", nargs='*', help='Training dataset with mixed labels. A whitespace separated list of folders containing .npz files or paths to individual .npz files.')
378 |
379 | parser.add_argument("--can_training_dataset", nargs='*', help='Training dataset with unmodified or canonical base labels. A whitespace separated list of folders containing .npz files or paths to individual .npz files.')
380 | parser.add_argument("--mod_training_dataset", nargs='*', help='Training dataset with modified labels. A whitespace separated list of folders containing .npz files or paths to individual .npz files.')
381 |
382 |     parser.add_argument("--validation_type",choices=['split','dataset'], help='How validation is performed. "split" means that a fraction of the training dataset, specified by --validation_fraction, will be used for validation. "dataset" means that a separate validation dataset is provided via the --validation_dataset parameter.',default="split")
383 |     parser.add_argument("--validation_fraction", help='Fraction of the training dataset to use for validation if --validation_type is set to "split"; otherwise ignored.', type=float, default=0.2)
384 | parser.add_argument("--validation_dataset",nargs='*', help='Validation dataset if --validation_type is set to dataset. A whitespace separated list of folders containing .npz files or paths to individual .npz files.')
385 |
386 | parser.add_argument("--prefix", help='Prefix name for the model checkpoints', default='model')
387 |
388 |     parser.add_argument("--weights", help='Weight of the positive (modified) label used in binary cross-entropy loss; the negative (unmodified) label always has a fixed weight of 1. A higher weight for modified labels favors recall, a lower weight favors precision. Choices are "equal", "auto", or a numeric weight for the positive (modified) label. "equal" assigns a weight of 1 to modified labels, while "auto" assigns weight=num_negative_samples/num_positive_samples.', default='equal')
389 |
390 |
391 | parser.add_argument("--model_save_path", help='Folder path for saving model checkpoints')
392 |     parser.add_argument("--epochs", help='Total number of training epochs', default=100, type=int)
393 | parser.add_argument("--batch_size", help='Batch Size', default=256, type=int)
394 | parser.add_argument("--retrain", help='Path to a model for retraining', default=None)
395 |
396 |     parser.add_argument("--fc_type", help='Type of fully connected layer to use in the classifier.', type=str, default='all', choices=['middle', 'all'])
397 | parser.add_argument("--model_type", help='Type of model to use', type=str, choices=['bilstm', 'transformer'])
398 |
399 | parser.add_argument("--num_layers", help='Number of transformer encoder or BiLSTM layers', type=int, default=3)
400 | parser.add_argument("--dim_feedforward", help='Dimension of feedforward layers in transformer encoder or size of hidden units in BiLSTM layers', type=int, default=100)
401 | parser.add_argument("--num_fc", help='Size of fully connected layer between encoder/BiLSTM and classifier', type=int, default=16)
402 | parser.add_argument("--embedding_dim", help='Size of embedding dimension for read and reference bases', type=int, default=4)
403 | parser.add_argument("--embedding_type", help='Type of embedding for bases', type=str, choices=['learnable', 'one_hot'], default='one_hot')
404 | parser.add_argument("--pe_dim", help='Dimension for positional encoding/embedding for transformer model.', type=int, default=16)
405 |     parser.add_argument("--pe_type", help='Type of positional encoding/embedding for the transformer model: "fixed" is a sinusoidal encoding, "embedding" is a dictionary lookup, and "parameter" is a learnable weight matrix.', type=str, choices=['fixed', 'embedding', 'parameter'], default='fixed')
406 | parser.add_argument("--nhead", help='Number of self-attention heads in transformer encoder for transformer model.', type=int, default=4)
407 | parser.add_argument("--include_ref", help='Whether to include reference sequence as features. Recommended.', default=False, action='store_true')
408 |     parser.add_argument("--train_w_wo_ref", help='Train on each example both with and without the reference sequence. Recommended if you will be using reference-free modification detection.', default=False, action='store_true')
409 |
410 | parser.add_argument("--lr", help='Learning rate', type=float, default=1e-4)
411 | parser.add_argument("--l2_coef", help='L2 regularization coefficient', type=float, default=1e-5)
412 |     parser.add_argument("--seed", help='Random seed used by PyTorch for reproducible weight initialization.', default=None)
413 |
414 | args = parser.parse_args()
415 |
416 | os.makedirs(args.model_save_path, exist_ok=True)
417 |
418 | mixed_training_dataset=get_files(args.mixed_training_dataset)
419 | can_training_dataset=get_files(args.can_training_dataset)
420 | mod_training_dataset=get_files(args.mod_training_dataset)
421 |
422 | validation_dataset=get_files(args.validation_dataset)
423 | validation_type=args.validation_type
424 | validation_fraction=args.validation_fraction
425 |
426 | valid_data, window, norm_type, strides_per_base, model_depth, full_signal = check_training_files(mixed_training_dataset, can_training_dataset,\
427 | mod_training_dataset, validation_dataset)
428 |
429 | if not valid_data:
430 | sys.exit(3)
431 |
432 |     model_len=strides_per_base*(2*window+1)
433 |
434 | model_config = dict(model_dims=(model_len,model_depth+1),window=window, model_type=args.model_type,
435 | num_layers=args.num_layers, dim_feedforward=args.dim_feedforward,
436 | num_fc=args.num_fc, embedding_dim=args.embedding_dim,
437 | embedding_type=args.embedding_type, include_ref=args.include_ref,
438 | pe_dim=args.pe_dim, nhead=args.nhead, pe_type=args.pe_type,
439 | l2_coef=args.l2_coef, lr=args.lr, model_save_path=args.model_save_path, fc_type=args.fc_type,
440 | train_w_wo_ref=args.train_w_wo_ref, weights=args.weights, norm_type=norm_type, full_signal=full_signal, strides_per_base=strides_per_base)
441 |
442 | args_dict=vars(args)
443 | args_str=''.join('%s: %s\n' %(k,str(v)) for k,v in args_dict.items())
444 | print(args_str, flush=True)
445 |
446 |     seed = random.randint(0, 0xffff_ffff_ffff_ffff) if args.seed is None else int(args.seed)
447 |
448 | training_dataset = [mixed_training_dataset, can_training_dataset, mod_training_dataset]
449 |
450 | with open(os.path.join(args.model_save_path,'args'),'w') as file:
451 | file.write('Command: python %s\n\n\n' %(' '.join(sys.argv)))
452 | file.write('------Parameters Used For Running DeepMod2------\n')
453 | for k in vars(args):
454 | file.write('{}: {}\n'.format(k,vars(args)[k]) )
455 |
456 | res=train(training_dataset, validation_dataset, validation_type, validation_fraction, model_config, epochs=args.epochs,prefix=args.prefix, retrain=args.retrain, batch_size=args.batch_size, args_str=args_str, seed=seed)
457 |
458 | print('Time taken=%.4f' %(time.time()-start_time), flush=True)
459 |
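Putting the arguments above together, a typical invocation might look like the following; all paths and hyperparameter values are placeholders:

```
python train/train_models.py \
    --mixed_training_dataset features/ \
    --validation_type split --validation_fraction 0.2 \
    --model_type bilstm --model_save_path trained_model/ \
    --include_ref --train_w_wo_ref \
    --weights auto --epochs 50 --batch_size 256 --lr 1e-4
```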
--------------------------------------------------------------------------------
/train/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch, os, random
3 | from pathlib import Path
4 | from sklearn.model_selection import train_test_split
5 | import sklearn.metrics
6 | import itertools
7 |
8 | def get_files(input_list):
9 | out_list=[]
10 |     if input_list is None:
11 | return out_list
12 | for item in input_list:
13 |
14 | if os.path.isdir(item):
15 | out_list.extend(list(Path(item).rglob("*.npz")))
16 |
17 | elif item[-4:]=='.npz':
18 | out_list.append(item)
19 |
20 | random.seed(0)
21 | random.shuffle(out_list)
22 | return out_list
23 |
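`get_files()` flattens a mixed list of directories and `.npz` paths into one shuffled file list; because the shuffle is seeded with 0, the ordering is reproducible across runs. A quick illustration with hypothetical paths:

```
# Collects every *.npz under features/run1/ (recursively) plus the single
# file features/extra.npz, then shuffles them deterministically
files = get_files(['features/run1/', 'features/extra.npz'])
```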
24 | def read_from_file(file_name, validation_type, validation_fraction, data_type, get_pos=False):
25 | data=np.load(file_name)
26 | mat=data['mat']
27 | base_qual=data['base_qual']
28 |
29 | features=np.dstack((mat, base_qual[:,:,np.newaxis]))
30 | features=torch.Tensor(features)
31 |
32 | base_seq=torch.Tensor(data['base_seq']).type(torch.LongTensor)
33 | ref_seq=torch.Tensor(data['ref_seq']).type(torch.LongTensor)
34 |
35 | labels=torch.Tensor(data['label'][:,np.newaxis])
36 |
37 | if get_pos:
38 | pos=data['ref_coordinates']
39 | chrom=data['ref_name']
40 | pos_data=np.vstack([chrom, pos]).T
41 |
42 | return features, base_seq, ref_seq, labels, pos_data
43 |
44 | if validation_type=='split':
45 | features_train, features_test, base_seq_train, base_seq_test, \
46 | ref_seq_train, ref_seq_test, labels_train, labels_test=\
47 | train_test_split(features, base_seq, ref_seq, labels, test_size=validation_fraction, random_state=42)
48 | if data_type=='train':
49 | return features_train, base_seq_train, ref_seq_train, labels_train
50 | else:
51 | return features_test, base_seq_test, ref_seq_test, labels_test
52 |
53 | else:
54 | return features, base_seq, ref_seq, labels
55 |
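`read_from_file()` and `check_training_files()` together document the `.npz` layout that `generate_features.py` produces. Below is a toy sketch of a file with that layout; all shapes and values are illustrative assumptions (N examples, L positions, D signal features), and the value 4 in `ref_seq` mirrors the `dummy_ref_seq` placeholder that `train()` uses when no reference is supplied:

```
import numpy as np

N, L, D = 100, 21, 10   # window=10 with strides_per_base=1 gives L=2*10+1
np.savez('toy_features.npz',
         mat=np.random.rand(N, L, D).astype(np.float32),     # signal features
         base_qual=np.random.rand(N, L).astype(np.float32),  # per-base quality
         base_seq=np.random.randint(0, 4, (N, L)),           # read bases
         ref_seq=np.random.randint(0, 5, (N, L)),            # reference bases (4 = none)
         label=np.random.randint(0, 2, N).astype(float),     # 1 = modified
         ref_name=np.array(['chr1'] * N),                    # used when get_pos=True
         ref_coordinates=np.random.randint(0, 10**6, N),
         # per-file metadata that check_training_files() requires to be
         # consistent across all training and validation files:
         norm_type='mad', window=10, strides_per_base=1,
         model_depth=D, full_signal=False)
```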
56 | def check_training_files(mixed_training_dataset, can_training_dataset,\
57 | mod_training_dataset, validation_dataset):
58 | norm_type=[str(np.load(file)['norm_type']) for file in itertools.chain.from_iterable([mixed_training_dataset, can_training_dataset,\
59 | mod_training_dataset, validation_dataset])]
60 | window=[int(np.load(file)['window']) for file in itertools.chain.from_iterable([mixed_training_dataset, can_training_dataset,\
61 | mod_training_dataset, validation_dataset])]
62 | strides_per_base=[int(np.load(file)['strides_per_base']) for file in itertools.chain.from_iterable([mixed_training_dataset, can_training_dataset,\
63 | mod_training_dataset, validation_dataset])]
64 |
65 | model_depth=[int(np.load(file)['model_depth']) for file in itertools.chain.from_iterable([mixed_training_dataset, can_training_dataset,\
66 | mod_training_dataset, validation_dataset])]
67 | full_signal=np.unique([np.load(file)['full_signal'] for file in itertools.chain.from_iterable([mixed_training_dataset, can_training_dataset, mod_training_dataset, validation_dataset])])
68 |
69 | if len(set(window))==1 and len(set(norm_type))==1 and len(set(strides_per_base))==1 and len(full_signal)==1:
70 | return True, window[0], norm_type[0], strides_per_base[0], model_depth[0], full_signal[0]
71 |
72 | elif len(set(window))>1:
73 | print('Inconsistent dataset with multiple window sizes')
74 |
75 | elif len(set(norm_type))>1:
76 | print('Inconsistent dataset with multiple normalization types')
77 |
78 |     elif len(set(strides_per_base))>1 or len(full_signal)>1:
79 |         print('Inconsistent dataset with multiple strides_per_base or full_signal settings')
80 |
81 | return False, window, norm_type, strides_per_base, model_depth, full_signal
82 |
83 | def generate_batches(files, validation_type, validation_fraction, data_type, batch_size=512):
84 | counter = 0
85 |
86 | print_freq=max(1, len(files)//10)
87 |
88 |     while counter<len(files):
   |         # ... (source lines 89-120, covering the rest of generate_batches and
   |         # the start of generate_batches_mixed_can_mod, are missing from this
   |         # export and could not be recovered) ...
121 |         batch_x=torch.vstack([d[0] for d in multi_batch_data if len(d[0])>0])
122 | batch_base_seq=torch.vstack([d[1] for d in multi_batch_data if len(d[1])>0])
123 | batch_ref_seq=torch.vstack([d[2] for d in multi_batch_data if len(d[2])>0])
124 | batch_y=torch.vstack([d[3] for d in multi_batch_data if len(d[3])>0])
125 |
126 | yield batch_x, batch_base_seq, batch_ref_seq, batch_y
127 |
128 | def get_stats(metrics_dict, dtype):
129 |
130 | loss_str='{} Loss:'.format(dtype)
131 | acc_str='{} Accuracy:'.format(dtype)
132 | prec_str='{} Precision:'.format(dtype)
133 | rec_str='{} Recall:'.format(dtype)
134 | f1_str='{} F1:'.format(dtype)
135 |
136 | for g in sorted(metrics_dict.keys()):
137 |
138 | acc=metrics_dict[g]['true']/max(1,metrics_dict[g]['len'])
139 | loss=metrics_dict[g]['loss']/max(1,metrics_dict[g]['len'])
140 | precision=metrics_dict[g]['TP']/max(1,metrics_dict[g]['TP']+metrics_dict[g]['FP'])
141 | recall=metrics_dict[g]['TP']/max(1,metrics_dict[g]['TP']+metrics_dict[g]['FN'])
142 | f1=2*precision*recall/(precision+recall) if precision*recall!=0 else 0
143 |
144 | if len(metrics_dict.keys())==1:
145 | x='Total'
146 | total_acc=acc
147 | else:
148 | x=g
149 |
150 | loss_str+=' %s: %.4f' %(x, loss)
151 | acc_str+=' %s: %.4f' %(x, acc)
152 | prec_str+=' %s: %.4f' %(x, precision)
153 | rec_str+=' %s: %.4f' %(x, recall)
154 | f1_str+=' %s: %.4f' %(x, f1)
155 |
156 | if len(metrics_dict.keys())>1:
157 | x='Total'
158 |
159 | acc=sum(f['true'] for f in metrics_dict.values())/max(1,sum(f['len'] for f in metrics_dict.values()))
160 | loss=sum(f['loss'] for f in metrics_dict.values())/max(1,sum(f['len'] for f in metrics_dict.values()))
161 | precision=sum(f['TP'] for f in metrics_dict.values())/max(1,sum(f['TP'] for f in metrics_dict.values()) +sum(f['FP'] for f in metrics_dict.values()))
162 | recall=sum(f['TP'] for f in metrics_dict.values())/max(1,sum(f['TP'] for f in metrics_dict.values())+sum(f['FN'] for f in metrics_dict.values()))
163 | f1=2*precision*recall/(precision+recall) if precision*recall!=0 else 0
164 |
165 | total_acc=acc
166 |
167 | loss_str+=' %s: %.4f' %(x, loss)
168 | acc_str+=' %s: %.4f' %(x, acc)
169 | prec_str+=' %s: %.4f' %(x, precision)
170 | rec_str+=' %s: %.4f' %(x, recall)
171 | f1_str+=' %s: %.4f' %(x, f1)
172 |
173 | return '\n'.join([loss_str, acc_str,prec_str, rec_str, f1_str]), total_acc
174 |
175 | def get_metrics(metrics_dict,name, batch_y, score, loss):
176 | eval_counts=sklearn.metrics.confusion_matrix(batch_y.cpu(),(score>0).cpu(),labels=[0,1])
177 | metrics_dict[name]['len']+=len(batch_y)
178 | metrics_dict[name]['TP']+=eval_counts[1,1]
179 | metrics_dict[name]['FP']+=eval_counts[0,1]
180 | metrics_dict[name]['FN']+=eval_counts[1,0]
181 | metrics_dict[name]['true']=metrics_dict[name]['true']+eval_counts[0,0]+eval_counts[1,1]
182 | metrics_dict[name]['loss']+=loss.item()*len(batch_y)
183 |
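A toy walk-through of `get_metrics()` and `get_stats()` on a single batch, assuming both functions from this file are in scope. Scores are logits, so `score>0` is a positive (modified) call:

```
import torch

metrics = {'Normal': {'TP': 0, 'FP': 0, 'FN': 0, 'loss': 0, 'len': 0, 'true': 0}}

batch_y = torch.tensor([[1.], [0.], [1.], [0.]])       # true labels
score = torch.tensor([[2.3], [-1.2], [-0.5], [0.7]])   # logits: TP, TN, FN, FP
loss = torch.nn.functional.binary_cross_entropy_with_logits(score, batch_y)

get_metrics(metrics, 'Normal', batch_y, score, loss)
stats_str, total_acc = get_stats(metrics, 'Testing')
print(stats_str)   # accuracy, precision and recall all come out to 0.5
```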
--------------------------------------------------------------------------------