├── docs
├── xron
│ ├── mkdocs.yml
│ └── docs
│ │ └── index.md
└── images
│ └── xron_logo.png
├── xron
├── nrhmm
│ ├── models
│ │ └── rhmm_mm_norm
│ │ │ ├── checkpoint
│ │ │ ├── ckpt-8609
│ │ │ ├── transition_matrix_m6A.npy
│ │ │ └── transition_matrix_control.npy
│ ├── __init__.py
│ ├── relabel_test.py
│ ├── hmm_eval.py
│ ├── split_chunks.py
│ ├── tandem_repeat_resquiggle.py
│ ├── profile.py
│ ├── transition_speed_test.py
│ ├── method_illustration.py
│ └── hmm_relabel.py
├── _version.py
├── __init__.py
├── utils
│ ├── __init__.py
│ ├── data_sanity_check.py
│ ├── index.py
│ ├── dataset_visualization.py
│ ├── merge_signal.py
│ ├── plot_op.py
│ ├── extract_seqs.py
│ ├── fast5_op.py
│ ├── hidden_plot.py
│ ├── transfer_bonito.py
│ ├── mm.py
│ ├── merge_datasets.py
│ ├── index_multi.py
│ ├── decode.py
│ ├── fastIO.py
│ ├── vq.py
│ ├── transfer_methylation.py
│ ├── sparse_op.py
│ └── tagging.py
├── xron_init.py
├── pore_models
│ └── transfer.py
├── gen_conf.py
├── config.toml
├── xron_annotate.py
├── entry.py
├── xron_train.py
├── xron_index_shelve.py
├── xron_index_lmdb.py
├── nn.py
├── xron_test.py
├── xron_train_base.py
├── watch_training_progress.py
├── xron_label.py
├── test_VQVAE_speech.py
└── xron_train_embedding.py
├── pyproject.toml
├── __init__.py
├── requirements.txt
├── .gitignore
├── setup.py
├── xron-samples
├── methylation_evaluation.py
├── plot_segmentation.py
└── summary.py
└── README.md
/docs/xron/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: My Docs
2 | 
--------------------------------------------------------------------------------
/docs/images/xron_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haotianteng/Xron/HEAD/docs/images/xron_logo.png
--------------------------------------------------------------------------------
/xron/nrhmm/models/rhmm_mm_norm/checkpoint:
--------------------------------------------------------------------------------
1 | latest checkpoint:ckpt-8609
2 | checkpoint file:ckpt-8609
3 | 
--------------------------------------------------------------------------------
/xron/nrhmm/models/rhmm_mm_norm/ckpt-8609:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haotianteng/Xron/HEAD/xron/nrhmm/models/rhmm_mm_norm/ckpt-8609
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools", "wheel", "Cython>=0.29", "numpy >= 1.15"]
3 | build-backend = "setuptools.build_meta"
--------------------------------------------------------------------------------
/xron/nrhmm/models/rhmm_mm_norm/transition_matrix_m6A.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/haotianteng/Xron/HEAD/xron/nrhmm/models/rhmm_mm_norm/transition_matrix_m6A.npy
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The Xron Authors. All Rights Reserved. 
2 | # 3 | #This Source Code Form is subject to the terms of the GNU General Public License v3.0 4 | 5 | -------------------------------------------------------------------------------- /xron/nrhmm/models/rhmm_mm_norm/transition_matrix_control.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haotianteng/Xron/HEAD/xron/nrhmm/models/rhmm_mm_norm/transition_matrix_control.npy -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | mappy==2.17 2 | numpy==1.24.4 3 | ont-fast5-api 4 | pandas 5 | fast-ctc-decode 6 | editdistance==0.5.3 7 | boostnano 8 | toml 9 | pysam 10 | flashlight-text 11 | kenlm 12 | numpy==1.24.4 13 | tqdm 14 | scikit-learn 15 | matplotlib 16 | pod5 17 | Biopython 18 | vbz_h5py_plugin -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | models/ 2 | xron.egg-info/ 3 | runs/ 4 | dist/ 5 | build/ 6 | models.zip 7 | *.swp 8 | *.pyc 9 | *test*.py 10 | *.vscode 11 | test.npy 12 | tmpy* 13 | *.png 14 | xron/utils/all_loss 15 | xron/utils/nearestembd 16 | xron/utils/nearestembd.py 17 | BoostNano/ 18 | xron.egg* 19 | xron/models/ 20 | dist/ 21 | -------------------------------------------------------------------------------- /xron/nrhmm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The Chiron Authors. All Rights Reserved. 2 | # 3 | #This Source Code Form is subject to the terms of the Mozilla Public 4 | #License, v. 2.0. If a copy of the MPL was not distributed with this 5 | #file, You can obtain one at http://mozilla.org/MPL/2.0/. 6 | 7 | """Initial file for the package""" 8 | -------------------------------------------------------------------------------- /xron/_version.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The Xron(Chiron) Authors. All Rights Reserved. 2 | # 3 | #This Source Code Form is subject to the terms of the Mozilla Public 4 | #License, v. 2.0. If a copy of the MPL was not distributed with this 5 | #file, You can obtain one at http://mozilla.org/MPL/2.0/. 6 | 7 | #Store the version here 8 | __version__ = '1.0.7' 9 | -------------------------------------------------------------------------------- /xron/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The Chiron Authors. All Rights Reserved. 2 | # 3 | #This Source Code Form is subject to the terms of the Mozilla Public 4 | #License, v. 2.0. If a copy of the MPL was not distributed with this 5 | #file, You can obtain one at http://mozilla.org/MPL/2.0/. 6 | 7 | """Initial file for the package""" 8 | from xron._version import __version__ -------------------------------------------------------------------------------- /xron/utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright 2023 The Xron Authors. All Rights Reserved. 3 | # 4 | #This Source Code Form is subject to the terms of the Mozilla Public 5 | #License, v. 2.0. If a copy of the MPL was not distributed with this 6 | #file, You can obtain one at http://mozilla.org/MPL/2.0/. 
7 | 8 | """Initial file for the package""" 9 | from xron._version import __version__ -------------------------------------------------------------------------------- /docs/xron/docs/index.md: -------------------------------------------------------------------------------- 1 | # Welcome to MkDocs 2 | 3 | For full documentation visit [mkdocs.org](https://www.mkdocs.org). 4 | 5 | ## Commands 6 | 7 | * `mkdocs new [dir-name]` - Create a new project. 8 | * `mkdocs serve` - Start the live-reloading docs server. 9 | * `mkdocs build` - Build the documentation site. 10 | * `mkdocs -h` - Print help message and exit. 11 | 12 | ## Project layout 13 | 14 | mkdocs.yml # The configuration file. 15 | docs/ 16 | index.md # The documentation homepage. 17 | ... # Other markdown pages, images and other files. 18 | -------------------------------------------------------------------------------- /xron/xron_init.py: -------------------------------------------------------------------------------- 1 | #Initialize xron package, need to run this when first time runnning xron 2 | import os 3 | import wget 4 | import xron 5 | import zipfile 6 | MODEL_URL="https://xronmodel.s3.us-east-1.amazonaws.com/models.zip" 7 | MODEL_PATH=xron.__path__[0]+"/models" 8 | def get_models(args): 9 | if not os.path.exists(MODEL_PATH): 10 | os.makedirs(MODEL_PATH) 11 | print("Downloading models...") 12 | wget.download(MODEL_URL, out=MODEL_PATH) 13 | print("\nExtracting models...") 14 | with zipfile.ZipFile(MODEL_PATH+"/models.zip", 'r') as zip_ref: 15 | zip_ref.extractall(MODEL_PATH) 16 | os.remove(MODEL_PATH+"/models.zip") 17 | print("Done!") 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /xron/utils/data_sanity_check.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import os 3 | import numpy as np 4 | import seaborn as sns 5 | from xron.xron_train_supervised import main 6 | import tkinter as tk 7 | from tkinter import filedialog 8 | 9 | root = tk.Tk() 10 | root.withdraw() 11 | 12 | #%% Load data 13 | DF = filedialog.askdirectory() 14 | # DF="/home/heavens/bridge_scratch/ELIGOS_dataset/IVT/control/kmers_guppy_4000_noise/" 15 | chunks = np.load(os.path.join(DF,"chunks.npy"),mmap_mode = "r") 16 | durations = np.load(os.path.join(DF,"durations.npy"),mmap_mode = "r") 17 | kmers = np.load(os.path.join(DF,"kmers.npy"),mmap_mode = "r") 18 | seqs = np.load(os.path.join(DF,"seqs.npy"),mmap_mode = "r") 19 | seq_lens = np.load(os.path.join(DF,"seq_lens.npy"),mmap_mode = "r") 20 | 21 | # %% Sanity check 22 | sns.distplot(seq_lens) -------------------------------------------------------------------------------- /xron/pore_models/transfer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Mon Jun 7 23:59:58 2021 3 | 4 | @author: Haotian Teng 5 | """ 6 | 7 | old_pore_f = "/home/heavens/twilight/CMU/Xron/xron/pore_models/control_hye.1.5mer_level_table.txt" 8 | new_pore_f = "/home/heavens/twilight/CMU/Xron/xron/pore_models/5mer_level_table.model" 9 | title = '\t'.join(['kmer','level_mean','level_stdv','sd_mean']) 10 | with open(old_pore_f,'r') as f: 11 | with open(new_pore_f,'w+') as wf: 12 | wf.write(title+'\n') 13 | for line in f: 14 | split_line = line.strip().split() 15 | kmer = split_line[0] 16 | mean = split_line[2] 17 | std_line = next(f).strip().split() 18 | assert std_line[0] == kmer 19 | std = std_line[2] 20 | dwell_line = next(f).strip().split() 21 | assert dwell_line[0] == kmer 22 | 
dwell = dwell_line[2] 23 | wf.write('\t'.join([kmer,mean,std,dwell])+'\n') 24 | -------------------------------------------------------------------------------- /xron/gen_conf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Sun Dec 26 19:38:25 2021 3 | 4 | @author: Haotian Teng 5 | """ 6 | import os 7 | import sys 8 | import argparse 9 | from xron_model import MM_CONFIG 10 | 11 | def main(args): 12 | base_config = MM_CONFIG() 13 | grid_key = [base_config.CNN['Layers'][0]['out_channels'], 14 | base_config.CNN['Layers'][1]['out_channels'], 15 | base_config.CNN['Layers'][2]['out_channels'], 16 | base_config.RNN['hidden_size'], 17 | base_config.PORE_MODEL['K']] 18 | grid = [[4,16,64], 19 | [16,64,128], 20 | [256,512,768], 21 | [256,512,768], 22 | [2,3,4,5]] 23 | 24 | 25 | # if __name__ == "__main__": 26 | # parser = argparse.ArgumentParser( 27 | # description='Calling training module.') 28 | # parser.add_argument('--module', required = True, 29 | # help = "The training module to call, can be Embedding, Supervised and Reinforce") 30 | # parser.add_argument('-o', '--model_folder', required = True, 31 | # help = "The folder to save folder at.") 32 | # args = parser.parse_args(sys.argv[1:]) 33 | # os.makedirs(args.model_folder,exist_ok=True) 34 | # main(args) -------------------------------------------------------------------------------- /xron/utils/index.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Mon Oct 3 11:47:48 2022 5 | 6 | @author: heavens 7 | """ 8 | import os 9 | import h5py 10 | import argparse 11 | from xron.utils.seq_op import fast5_shallow_iter 12 | from xron.utils.fastIO import read_fastq 13 | 14 | def indexing(args): 15 | fastq_records = read_fastq(args.fastq) 16 | fast5_records = {} 17 | print("Indexing fastq files.") 18 | with open(args.fastq+'.index', 'w+') as f: 19 | for root,abs_path in fast5_shallow_iter(args.fast5,tqdm_bar = True): 20 | read_ids = list(root.keys()) 21 | for id in read_ids: 22 | fast5_records[id[5:]] = abs_path #read_id is like "read_00000000-0000-0000-0000-0000000000" 23 | for id in fastq_records['name']: 24 | if id in fast5_records.keys(): 25 | f.write(id+'\t'+fast5_records[id]+'\n') 26 | else: 27 | raise KeyError('fastq readid %s not found in fast5'%(id)) 28 | print("Indexing fastq file has been finished.") 29 | 30 | if __name__ == "__main__": 31 | parser = argparse.ArgumentParser(description='Process some integers.') 32 | parser.add_argument('--fast5', required = True, type=str, help='folder that contains fast5 output') 33 | parser.add_argument('--fastq', required = True, type=str, help='The merged fastq file') 34 | args = parser.parse_args() 35 | indexing(args) -------------------------------------------------------------------------------- /xron/utils/dataset_visualization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Fri Nov 19 22:51:28 2021 3 | 4 | @author: Haotian Teng 5 | """ 6 | import numpy as np 7 | from itertools import product,repeat 8 | 9 | dataset = "/home/heavens/bridge_scratch/m6A_Nanopore/merged_dataset_diff/" 10 | chunks_f = dataset + "chunks.npy" 11 | seq_lens_f = dataset + "seq_lens.npy" 12 | seqs_f = dataset + "seqs.npy" 13 | chunks = np.load(chunks_f) 14 | seq_lens = np.load(seq_lens_f) 15 | seqs = np.load(seqs_f) 16 | 17 | def check_ratio(sequences): 18 | control_count = 0 19 | m_count = 0 20 | 
other_count = 0 21 | for seq in sequences: 22 | if 'A' in seq: 23 | if 'M' in seq: 24 | raise("ValueError: The read has both M and A in sequence.") 25 | control_count +=1 26 | elif 'M' in seq: 27 | m_count += 1 28 | else: 29 | other_count +=1 30 | return control_count,m_count, control_count + m_count+other_count 31 | 32 | def check_Mkmer(sequences,k = 5): 33 | N = ['M','C','G','T'] 34 | kmers = product(*repeat(N,k)) 35 | kmer_dict = {} 36 | for kmer in kmers: 37 | kmer_dict[''.join(kmer)] = 0 38 | for seq in sequences: 39 | occurs = [i for i,c in enumerate(seq) if c=='M'] 40 | for o in occurs: 41 | curr_kmer = seq[o:o+k] 42 | if len(curr_kmer) == k: 43 | kmer_dict[curr_kmer] +=1 44 | return kmer_dict 45 | 46 | c,m,all = check_ratio(seqs) 47 | kmer_dict = check_Mkmer(seqs) -------------------------------------------------------------------------------- /xron/utils/merge_signal.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Tue Aug 3 23:24:39 2021 3 | 4 | @author: Haotian Teng 5 | """ 6 | import numpy as np 7 | import os 8 | import sys 9 | import argparse 10 | 11 | 12 | def main(args): 13 | chunks = [] 14 | counts = [] 15 | for i in args.input: 16 | try: 17 | c = np.load(i) 18 | except FileNotFoundError: 19 | print("Didn't find %s, skip."%(i)) 20 | continue 21 | counts.append(len(c)) 22 | chunks.append(c) 23 | print("Read %d chunks from %s"%(len(c),i)) 24 | if args.equal_size: 25 | min_size = min(counts) 26 | chunks = [x[:min_size] for x in chunks] 27 | chunk_all = np.concatenate(chunks,axis = 0) 28 | np.save(os.path.join(args.output,'chunks_all.npy'),chunk_all) 29 | 30 | if __name__ == "__main__": 31 | parser = argparse.ArgumentParser( 32 | description='Merge signal chunks from different folder.') 33 | parser.add_argument('-i', '--input', default = None, 34 | help = "The input signal chunks, separate by comma.") 35 | parser.add_argument('-o', '--output', required = True, 36 | help = "The output folder of the merged dataset.") 37 | parser.add_argument('--equal_size',default = False, type = bool, 38 | help = "If make the size from each chunks equal.") 39 | args = parser.parse_args(sys.argv[1:]) 40 | args.input = args.input.split(',') 41 | if not os.path.exists(args.output): 42 | os.makedirs(args.output) 43 | main(args) -------------------------------------------------------------------------------- /xron/utils/plot_op.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Wed Sep 28 07:34:49 2022 5 | 6 | @author: heavens 7 | """ 8 | from matplotlib import pyplot as plt 9 | import numpy as np 10 | 11 | def auc_plot(TP,FP,axs,add_head = False,print_AUC = True,**kwargs): 12 | c = kwargs['color'] if 'color' in kwargs.keys() else None 13 | label = kwargs['label'] if 'color' in kwargs.keys() else None 14 | linewidth = kwargs['linewidth'] if 'linewidth' in kwargs.keys() else 1 15 | TP = np.asarray(TP) 16 | FP = np.asarray(FP) 17 | assert np.all(TP[1:] - TP[:-1]>=0) 18 | assert np.all(FP[1:] - FP[:-1]>=0) 19 | if add_head: 20 | TP = np.concatenate(([0],TP,[1])) 21 | FP = np.concatenate(([0],FP,[1])) 22 | axs.plot(FP[:2],TP[:2],color = c,label = label,linewidth = linewidth) 23 | axs.set_xlabel("False Positive") 24 | axs.set_ylabel("True Positive") 25 | axs.set_xlim([-0.05, 1.05]) 26 | axs.set_ylim([-0.05, 1.05]) 27 | if print_AUC: 28 | axs.text(x = 0.8,y = 0,s = "AUC = %.2f"%(AUC(TP,FP))) 29 | axs.plot([-0.05,1.05],[-0.05,1.05],color = 
'grey') 30 | 31 | def AUC(TP,FP): 32 | """Calculate Area under curve given the true positive and false positive 33 | array 34 | """ 35 | TP = TP[::-1] if TP[0]>TP[-1] else TP 36 | FP = FP[::-1] if FP[0]>FP[-1] else FP 37 | TP = [0] + TP if TP[0] != 0 else TP 38 | TP = TP + [1] if TP[-1] != 1 else TP 39 | FP = [0] + FP if FP[0] != 0 else FP 40 | FP = FP + [1] if FP[-1] != 1 else FP 41 | FP = np.asarray(FP) 42 | TP = np.asarray(TP) 43 | return np.sum((TP[1:] + TP[:-1])*(FP[1:]-FP[:-1])/2) 44 | -------------------------------------------------------------------------------- /xron/utils/extract_seqs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Wed Jan 26 12:55:10 2022 3 | 4 | @author: Haotian Teng 5 | """ 6 | import os,sys,argparse 7 | from tqdm import tqdm 8 | import h5py 9 | import numpy as np 10 | from xron.utils.seq_op import fast5_iter 11 | from pathlib import Path 12 | 13 | 14 | def retrive_fastq(seq_h): 15 | try: 16 | seq = np.asarray(seq_h['BaseCalled_template']['Fastq']).tobytes().decode('utf-8') 17 | except: 18 | seq = str(np.asarray(seq_h['BaseCalled_template']['Fastq']).astype(str)) 19 | return seq 20 | 21 | def extract(args): 22 | iterator = fast5_iter(args.input_fast5,mode = 'r') 23 | for read_h,signal,fast5_f,read_id in tqdm(iterator): 24 | seq = retrive_fastq(read_h['Analyses/Basecall_1D_%s'%(args.basecall_entry)]) 25 | with open(os.path.join(args.output,Path(fast5_f).stem+'.fastq'),'a+') as f: 26 | f.write(seq) 27 | 28 | 29 | if __name__ == "__main__": 30 | parser = argparse.ArgumentParser(prog='xron', 31 | description='A Unsupervised Nanopore basecaller.') 32 | parser.add_argument('-i', 33 | '--input_fast5', 34 | required = True, 35 | help="File path or Folder path to the fast5 file.") 36 | parser.add_argument('-o', 37 | '--output', 38 | required = True, 39 | help="Output folder.") 40 | parser.add_argument('--basecall_entry', 41 | default = "000", 42 | help="The entry number in /Analysis/ to look into, for\ 43 | example 000 means looking for Basecall_1D_000.") 44 | 45 | FLAGS = parser.parse_args(sys.argv[1:]) 46 | os.makedirs(FLAGS.output,exist_ok = True) 47 | extract(FLAGS) -------------------------------------------------------------------------------- /xron/utils/fast5_op.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Mar 15 16:47:56 2022 5 | 6 | @author: haotian 7 | """ 8 | import h5py 9 | import os 10 | import numpy as np 11 | from collections import defaultdict 12 | from matplotlib import pyplot as plt 13 | from tqdm import tqdm 14 | 15 | 16 | def seq2kmer(seq:str,k:int = 5): 17 | return [seq[x:x+k] for x,_ in enumerate(seq[:-k+1])] 18 | 19 | prefix = "/home/heavens/bridge_scratch/NA12878_RNA_IVT/guppy_train/fast5s" 20 | kmer_dict_p1 = defaultdict(list) 21 | kmer_dict_m1 = defaultdict(list) 22 | k = 5 23 | for f in tqdm(os.listdir(prefix)): 24 | if f.endswith("fast5"): 25 | with h5py.File(os.path.join(prefix,f),'r') as root: 26 | for read in root: 27 | try: 28 | corrected = root[read]['Analyses/Segmentation_000/Reference_corrected'] 29 | signal = np.asarray(root[read]['Raw/Signal'])[::-1] 30 | seq =np.asarray(corrected['ref_seq']).item().decode() 31 | kmers = seq2kmer(seq,k = k) 32 | ref_sig_idx = np.asarray(corrected['ref_sig_idx']) 33 | for i in np.arange(len(kmers)): 34 | c,c_n = ref_sig_idx[i+k//2],ref_sig_idx[i+k//2+1] 35 | kmer = kmers[i] 36 | kmer_dict_p1[kmer]+= list(signal[c:c_n]) 
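                    # kmer_dict_p1 above collects the raw-signal event aligned to each kmer's
                    # central base; the loop below fills kmer_dict_m1 with the event one position
                    # before the kmer centre, so the per-kmer signal spread (std) of the two
                    # alignment offsets can be compared at the end of the script.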
37 | for i in np.arange(len(kmers)): 38 | c,c_n = ref_sig_idx[i+k//2-1],ref_sig_idx[i+k//2] 39 | kmer = kmers[i] 40 | kmer_dict_m1[kmer]+= list(signal[c:c_n]) 41 | 42 | except KeyError: 43 | pass 44 | 45 | 46 | std_m1 = [] 47 | std_p1 = [] 48 | for key,val in kmer_dict_m1.items(): 49 | std_m1.append(np.std(val)) 50 | std_p1.append(np.std(kmer_dict_p1[key])) -------------------------------------------------------------------------------- /xron/utils/hidden_plot.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Mon Mar 28 23:54:38 2022 3 | 4 | @author: Haotian Teng 5 | """ 6 | import h5py 7 | import numpy as np 8 | from matplotlib import pyplot as plt 9 | strides = 11 10 | root = h5py.File("/home/heavens/Documents/UCSC_Run1_20180129_IVT_RNA_201.fast5",'r') 11 | for read_id in root: 12 | if read_id == "read_40634cc4-ff9c-41f3-bd2c-901cb23c0179": 13 | plt.figure(figsize = (40,20)) 14 | plt.plot(np.asarray(root[read_id]['Raw/Signal'])[::-1]) 15 | read_h = root[read_id] 16 | basecalled = read_h['Analyses/Basecall_1D_000/BaseCalled_template'] 17 | hidden = np.asarray(basecalled['Hidden']) 18 | signal = np.asarray(read_h['Raw/Signal']) 19 | move = basecalled['Move'] 20 | position_sig = np.repeat(np.cumsum(move)-1,repeats = strides).astype(np.int32)[:len(signal)] 21 | segmentation = np.where(np.diff(position_sig))[0] 22 | 23 | plt.figure(figsize = (80,40)) 24 | max_T = 40000 25 | plt.plot(signal[::-1][-(max_T+5000):],color = "red") 26 | plt.vlines(x = segmentation[segmentation>(segmentation[-1]-max_T)] - len(signal)+max_T, 27 | ymin = min(signal), 28 | ymax = max(signal), 29 | color = "green", 30 | linestyles = "dotted", 31 | linewidth = 0.01, 32 | antialiased=False) 33 | 34 | plt.figure(figsize = (80,40)) 35 | max_T = 40000 36 | plt.plot(signal[:max_T],color = "red") 37 | segmentation = segmentation[-1] - segmentation[::-1] 38 | plt.vlines(x = segmentation[segmentation=2.10.0', 18 | 'numpy==1.24.4', 19 | 'statsmodels>=0.8.0', 20 | 'tqdm>=4.23.0', 21 | 'scipy>=1.0.1', 22 | 'biopython==1.73', 23 | 'google-auth==2.18.1', 24 | 'oauthlib==3.2.2', 25 | 'packaging>=18.0', 26 | 'ont-fast5-api>=0.3.1', 27 | 'wget>=3.2', 28 | 'pysam>=0.21.0', 29 | 'tensorboard', 30 | 'matplotlib', 31 | 'seaborn', 32 | 'pandas', 33 | 'toml', 34 | 'fast-ctc-decode', 35 | 'torch>=1.12.0', 36 | 'torchvision>=0.13.0', 37 | 'torchaudio>=0.12.0', 38 | 'boostnano', 39 | 'editdistance==0.6.1', 40 | 'boostnano', 41 | 'vbz_h5py_plugin', 42 | ] 43 | exec(open('xron/_version.py').read()) #readount the __version__ variable 44 | setup( 45 | name = 'xron', 46 | packages = find_packages(exclude=["*.test", "*test.*", "test.*", "test"]), 47 | version = __version__, 48 | include_package_data=True, 49 | description = 'A deep neural network basecaller for nanopore sequencing.', 50 | author = 'Haotian Teng', 51 | author_email = 'havens.teng@gmail.com', 52 | url = 'https://github.com/haotianteng/Xron', 53 | download_url = 'https://github.com/haotianteng/Xron/archive/1.0.0.tar.gz', 54 | keywords = ['basecaller', 'nanopore', 'sequencing','neural network','RNA methylation'], 55 | license="GPL 3.0", 56 | classifiers = ['License :: OSI Approved :: GNU General Public License v3 (GPLv3)'], 57 | install_requires=install_requires, 58 | entry_points={'console_scripts':['xron=xron.entry:main'],}, 59 | long_description=long_description, 60 | include_dirs = [np.get_include()], 61 | long_description_content_type='text/markdown', 62 | ) 63 | 
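The plotting utilities above (hidden_plot.py here, and mm.py further below) recover a per-base segmentation of the raw signal from the basecaller's Move table by expanding each model output step to a fixed number of raw samples (the stride) and marking every change of the assigned base index. The snippet below is a minimal sketch of that mapping; the helper name move_to_segmentation and the toy inputs are illustrative only and are not part of the repository.

import numpy as np

def move_to_segmentation(move, stride, signal_len):
    """Map a Move table (one 0/1 flag per model step) to per-base
    segment boundaries in raw-signal coordinates."""
    move = np.asarray(move)
    # base index assigned to every model output step, expanded to signal resolution
    position_sig = np.repeat(np.cumsum(move) - 1, repeats=stride).astype(np.int32)[:signal_len]
    # a boundary sits wherever the assigned base index changes
    segmentation = np.where(np.diff(position_sig))[0]
    return position_sig, segmentation

# toy example: 10 model steps, stride 3, new bases starting at steps 0, 4 and 7
pos, seg = move_to_segmentation([1, 0, 0, 0, 1, 0, 0, 1, 0, 0], stride=3, signal_len=30)
print(seg)  # boundaries at raw-signal indices 11 and 20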
-------------------------------------------------------------------------------- /xron/nrhmm/relabel_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Tue Jun 28 02:32:18 2022 3 | 4 | @author: Haotian Teng 5 | """ 6 | import toml 7 | import numpy as np 8 | import seaborn as sns 9 | from matplotlib import pyplot as plt 10 | from xron.nrhmm.hmm_relabel import get_effective_kmers 11 | 12 | 13 | chunk_control = np.load("/home/heavens/bridge_scratch/ime4_Yearst/IVT/control/rep1/kmers_xron_4000_noise/chunks.npy") 14 | chunk_m6A = np.load("/home/heavens/bridge_scratch/ime4_Yearst/IVT/m6A/rep1/kmers_guppy_4000_dwell/chunks.npy") 15 | kmer_m6A = np.load("/home/heavens/bridge_scratch/ime4_Yearst/IVT/m6A/rep1/kmers_guppy_4000_dwell/kmers.npy") 16 | chunk_control = np.load("/home/heavens/bridge_scratch/ime4_Yearst/IVT/control/rep1/kmers_guppy_4000_dwell/chunks.npy") 17 | chunk_control_renorm = np.load("/home/heavens/bridge_scratch/ime4_Yearst/IVT/control/rep1/kmers_guppy_4000_dwell/chunks_renorm.npy") 18 | kmer_control = np.load("/home/heavens/bridge_scratch/ime4_Yearst/IVT/control/rep1/kmers_guppy_4000_dwell/kmers.npy") 19 | chunk_control_NA = np.load("/home/heavens/bridge_scratch/NA12878_RNA_IVT/guppy_train/kmers_guppy_4000_dwell/chunks.npy") 20 | config = toml.load("/home/heavens/bridge_scratch/ime4_Yearst/IVT/m6A/rep1/kmers_guppy_4000_dwell/config.toml") 21 | chunk_control_xron = np.load("/home/heavens/bridge_scratch/NA12878_RNA_IVT/xron_output/kmers_guppy_4000_dwell/chunks_renorm.npy") 22 | chunk_m6A90 = np.load("/home/heavens/bridge_scratch/m6A_Nanopore_RNA002/data/m6A_90_pct/20210430_1745_X2_FAQ15454_23428362/kmers_guppy_4000_dwell/chunks.npy") 23 | idx2kmer = config['idx2kmer'] 24 | effective_kmers = get_effective_kmers("!MA",idx2kmer) 25 | 26 | figs,axs = plt.subplots(ncols = 3,sharey = True,figsize = (10,5)) 27 | sns.distplot(np.mean(chunk_m6A_renorm,axis = 1),ax = axs[0],label = "m6A-Eva") 28 | sns.distplot(np.mean(chunk_control_renorm,axis = 1),ax = axs[0],label = "control-Eva") 29 | axs[0].legend() 30 | sns.distplot(np.mean(chunk_m6A_renorm,axis = 1),ax = axs[1],label = "m6A-Eva") 31 | sns.distplot(np.mean(chunk_control_xron,axis = 1),ax = axs[1],label = "control-NA12878") 32 | axs[1].legend() 33 | sns.distplot(np.mean(chunk_m6A90,axis = 1),ax = axs[2],label = "m6A-NA12878") 34 | sns.distplot(np.mean(chunk_control_xron,axis = 1),ax = axs[2],label = "control-xron-NA12878") 35 | axs[2].legend() 36 | plt.legend() -------------------------------------------------------------------------------- /xron/utils/transfer_bonito.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Fri Aug 13 16:03:48 2021 3 | 4 | @author: Haotian Teng 5 | """ 6 | import os,sys 7 | import numpy as np 8 | from matplotlib import pyplot as plt 9 | import argparse 10 | def map_base(string): 11 | base_dict = {'A':1,'C':2,'G':3,'T':4} 12 | return np.asarray([base_dict[b] for b in string]) 13 | 14 | 15 | #bonito_ref = "/home/heavens/twilight_data1/S10_DNA/20170322_c4_watermanag_S10/bonito_output/ctc_data/references.npy" 16 | #bonito_len = "/home/heavens/twilight_data1/S10_DNA/20170322_c4_watermanag_S10/bonito_output/ctc_data/reference_lengths.npy" 17 | #ref = np.load(bonito_ref) 18 | #ref_len = np.load(bonito_len) 19 | #plt.hist(ref_len) 20 | def main(args): 21 | print("Read the chunks...") 22 | chunk_f = os.path.join(args.input,'chunks.npy') 23 | seq_f = os.path.join(args.input,'seqs.npy') 24 | seq_len_f = 
os.path.join(args.input,'seq_lens.npy') 25 | seq = np.load(seq_f) 26 | seq_len = np.load(seq_len_f) 27 | chunk = np.load(chunk_f) 28 | print("Process...") 29 | plt.figure() 30 | plt.hist(seq_len,bins = 80) 31 | mask = np.logical_and(seq_len5) 32 | seq = seq[mask] 33 | seq_len_filt = seq_len[mask] 34 | plt.figure() 35 | plt.hist(seq_len,bins = 80) 36 | pad_w = np.max(seq_len) 37 | seq_filt = [map_base(s) for s in seq] 38 | seq_filt = [np.pad(x,(0,pad_w-len(x))) for x in seq_filt] 39 | seq_filt = np.asarray(seq_filt) 40 | print("Write output.") 41 | os.makedirs(args.output,exist_ok = True) 42 | chunk_out_f = os.path.join(args.output,'chunks.npy') 43 | seq_out_f = os.path.join(args.output,'references.npy') 44 | seq_len_out_f = os.path.join(args.output,'reference_lengths.npy') 45 | np.save(seq_out_f,seq_filt) 46 | np.save(seq_len_out_f,seq_len_filt) 47 | np.save(chunk_out_f,chunk[mask]) 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser( 52 | description='Transfer dataset into bonito-readable format.') 53 | parser.add_argument('-i', '--input', required = True, 54 | help = "The folder contains the chunks,seqs and seq_lens file.") 55 | parser.add_argument('-o', '--output', required = True, 56 | help = "The output folder.") 57 | args = parser.parse_args(sys.argv[1:]) 58 | main(args) -------------------------------------------------------------------------------- /xron/nrhmm/hmm_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Mon May 9 06:15:29 2022 3 | 4 | @author: Haotian Teng 5 | """ 6 | import os 7 | import sys 8 | import time 9 | import toml 10 | import torch 11 | import argparse 12 | import itertools 13 | import numpy as np 14 | from matplotlib import pyplot as plt 15 | from xron.nrhmm.hmm import GaussianEmissions, RHMM 16 | from xron.nrhmm.hmm_input import Kmer2Transition, Kmer_Dataset 17 | from torchvision import transforms 18 | from torch.utils.data.dataloader import DataLoader 19 | from xron.xron_train_base import DeviceDataLoader 20 | 21 | 22 | 23 | 24 | 25 | if __name__ == "__main__": 26 | parser = argparse.ArgumentParser( 27 | description='Evaluating using the RHMM model') 28 | parser.add_argument("-i","--input", type = str, required = True, 29 | help = "Data folder contains the chunk, kmer sequence.") 30 | parser.add_argument('-o', '--output', required = True, 31 | help = "The folder to save folder at.") 32 | parser.add_argument("-b","--batch_size", type = int, default = 50, 33 | help = "The batch size to train.") 34 | parser.add_argument("--lr",type = float, default = 1e-2, 35 | help = "The initial training learning rate.") 36 | parser.add_argument("--report",type = int, default = 10, 37 | help = "Report the loss and save the model every report cycle.") 38 | parser.add_argument("--certain_methylation",action = "store_false", 39 | dest = "kmer_replacement", 40 | help = "If we are sure about the methylation state.") 41 | parser.add_argument("--optimizer",type = str, default = "Adam", 42 | help = "The optimizer used to train the model.") 43 | parser.add_argument('--epoches', default = 10, type = int, 44 | help = "The number of epoches to train.") 45 | parser.add_argument("--device",type = str, default = "cuda", 46 | help = "The device used to train the model.") 47 | parser.add_argument('--load', dest='retrain', action='store_true', 48 | help='Load existed model.') 49 | parser.add_argument('--moving_average', type = float, default = 0.0, 50 | help="The factor of moving average, 0 means no 
delay.") 51 | parser.add_argument('--trainable_bases', type = str, default = None, 52 | help="A magic string AB!CD or AB that gives the trainable bases and NOT trainalbe bases, for example MC means trains on kmer that must contains M and C, A!M means trains on kmer that must contains A but not contains M.") 53 | args = parser.parse_args(sys.argv[1:]) 54 | os.makedirs(args.model_folder,exist_ok=True) 55 | train(args) 56 | -------------------------------------------------------------------------------- /xron/nrhmm/split_chunks.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Fri Jul 1 07:32:52 2022 5 | 6 | @author: haotian teng 7 | """ 8 | 9 | import os 10 | import sys 11 | import toml 12 | import numpy as np 13 | import argparse 14 | 15 | def split(args): 16 | collections = {} 17 | fs = os.listdir(args.input) 18 | for f in fs: 19 | if f.endswith(".npy"): 20 | try: 21 | collections[f] = np.load(os.path.join(args.input,f)) 22 | except ValueError: 23 | continue 24 | s = np.array([len(x) for x in collections.values()]) 25 | if not(np.all(s==s[0])): 26 | print("Warning the npy files inside the folder has different size.") 27 | for i in np.arange(0,max(s),args.batch_size): 28 | sub_f = os.path.join(args.input,args.prefix+str(i//args.batch_size)) 29 | os.makedirs(sub_f,exist_ok=True) 30 | for f,data in collections.items(): 31 | np.save(os.path.join(sub_f,f),data[i:i+args.batch_size]) 32 | 33 | def merge(args): 34 | fs = os.listdir(args.input) 35 | fs = [x for x in fs if args.prefix in x] 36 | i = 0 37 | collections = {} 38 | while args.prefix+str(i) in fs: 39 | sub_f = os.path.join(args.input,args.prefix + str(i)) 40 | for f in os.listdir(sub_f): 41 | if f.endswith(".npy"): 42 | pieces = np.load(os.path.join(sub_f,f)) 43 | if f in collections.keys(): 44 | collections[f].append(pieces) 45 | else: 46 | collections[f] = [pieces] 47 | i+=1 48 | for key in collections.keys(): 49 | print("merge %s"%(key)) 50 | if collections[key][0].ndim > 1: 51 | np.save(os.path.join(args.input,key),np.vstack(tuple(collections[key]))) 52 | else: 53 | np.save(os.path.join(args.input,key),np.hstack(tuple(collections[key]))) 54 | 55 | def main(args): 56 | if args.merge: 57 | merge(args) 58 | else: 59 | split(args) 60 | 61 | if __name__ == "__main__": 62 | parser = argparse.ArgumentParser( 63 | description='Training RHMM model') 64 | parser.add_argument("-i","--input", type = str, required = True, 65 | help = "Data folder contains the chunk, kmer sequence.") 66 | parser.add_argument("-b","--batch_size", type = int, default = 4000, 67 | help = "The batch size of each subfolder.") 68 | parser.add_argument("--prefix", type = str, default = "subdata", 69 | help = "The prefix of the sub-folders.") 70 | parser.add_argument("--reverse",action = "store_true", dest = "merge", 71 | help = "Reverse the split operation, surrogate the data.") 72 | args = parser.parse_args(sys.argv[1:]) 73 | main(args) 74 | -------------------------------------------------------------------------------- /xron-samples/methylation_evaluation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Wed Jul 20 00:42:56 2022 3 | This script is used to evaluate the methylation ratio of the basecalled reads for partially methylated IVT dataset. 
4 | @author: Haotian Teng 5 | """ 6 | import os 7 | import pandas as pd 8 | from tqdm import tqdm 9 | import numpy as np 10 | from xron.utils.fastIO import read_fastqs 11 | import seaborn as sns 12 | from matplotlib import pyplot as plt 13 | 14 | def m6A_ratio(sequences,reduced_sum = True): 15 | M_count = [x.count("M") for x in sequences] 16 | A_count = [x.count("A") for x in sequences] 17 | if reduced_sum: 18 | return np.sum(M_count)/(np.sum(A_count) + np.sum(M_count)) 19 | else: 20 | return np.asarray(M_count)/(np.asarray(A_count) + np.asarray(M_count)) 21 | 22 | if __name__ == "__main__": 23 | scratch = os.environ['SCRATCH'] 24 | 25 | x = [0,0.25,0.50,0.75,0.90,1.0] 26 | m6A_fs = [f"{scratch}/ime4_Yearst/IVT/control/rep1/xron_crosslink/fastqs", 27 | f"{scratch}/m6A_Nanopore_RNA002/data/m6A_25_pct/20210430_1751_X2_FAP66339_8447fb8b/xron_crosslink/fastqs", 28 | f"{scratch}/m6A_Nanopore_RNA002/data/m6A_50_pct/20210430_1751_X3_FAQ16600_fe8f7999/xron_crosslink/fastqs", 29 | f"{scratch}/m6A_Nanopore_RNA002/data/m6A_75_pct/20210430_1745_X1_FAQ15457_c865db38/xron_crosslink/fastqs", 30 | f"{scratch}/m6A_Nanopore_RNA002/data/m6A_90_pct/20210430_1745_X2_FAQ15454_23428362/xron_crosslink/fastqs", 31 | f"{scratch}/ime4_Yearst/IVT/m6A/rep1/xron_crosslink/fastqs"] 32 | 33 | # x = [0,1.0] 34 | # m6A_fs = ["/home/heavens/bridge_scratch/ime4_Yearst/IVT/control/rep1/xron_crosslink_finetune/fastqs", 35 | # "/home/heavens/bridge_scratch/ime4_Yearst/IVT/m6A/rep1/xron_crosslink_finetune/fastqs"] 36 | 37 | # x = [0,0.3] 38 | # m6A_fs = ["/home/heavens/bridge_scratch/ime4_Yearst/Yearst/ko_raw_fast5/xron_crosslink/fastqs", 39 | # "/home/heavens/bridge_scratch/ime4_Yearst/Yearst/wt_raw_fast5/xron_crosslink/fastqs"] 40 | 41 | records = [read_fastqs(x) for x in tqdm(m6A_fs)] 42 | ratios = [m6A_ratio(x['sequences'],reduced_sum = False) for x in tqdm(records)] 43 | 44 | mix_prop = [[xi]*len(ratios[i]) for i,xi in enumerate(x)] 45 | ratios = np.concatenate(ratios,axis = 0) 46 | mix_prop = np.concatenate(mix_prop,axis = 0) 47 | df = pd.DataFrame({"basecall_ratio": ratios,"prepare_ratio": mix_prop}) 48 | 49 | fig, axes = plt.subplots(figsize=(5,5)) 50 | # sns.violinplot(data = df,x = "mix_prop",y = "ratios",showmeans = True,showmedians = True) 51 | sns.boxplot(data = df,x = "prepare_ratio",y = "basecall_ratio",showfliers = False,ax = axes) 52 | axes.set_xlabel("m6A propertion during IVT") 53 | axes.set_ylabel("Basecalled m6A ratio") 54 | fig.savefig("/home/heavens/bridge_scratch/Xron_Project/benchmark/IVT_read/methylation_ratio.png",dpi = 300) 55 | fig.savefig(f"{scratch}/Xron_Project/benchmark/IVT_read/methylation_ratio.pdf",dpi = 300,format = "pdf") 56 | -------------------------------------------------------------------------------- /xron/utils/mm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Sat Mar 26 14:24:27 2022 3 | 4 | @author: Haotian Teng 5 | """ 6 | 7 | import h5py 8 | from xron.utils import seq_op 9 | import numpy as np 10 | from umap import UMAP 11 | from sklearn.manifold import TSNE 12 | from matplotlib import pyplot as plt 13 | import matplotlib.cm as cm 14 | from typing import List 15 | 16 | fast5_f = "/home/heavens/Documents/FAQ15457_pass_6cc85380_1.fast5" 17 | root = h5py.File(fast5_f,mode = 'r') 18 | k = 5 19 | strides = 11 20 | for read_id in root: 21 | read_h = root[read_id] 22 | basecalled = read_h['Analyses/Basecall_1D_001/BaseCalled_template'] 23 | hidden = np.asarray(basecalled['Hidden']) 24 | signal = np.asarray(read_h['Raw/Signal']) 
25 | move = basecalled['Move'] 26 | position_sig = np.repeat(np.cumsum(move)-1,repeats = strides).astype(np.int32)[:len(signal)] 27 | segmentation = np.where(np.diff(position_sig))[0] 28 | position = np.cumsum(move) 29 | fastq = np.asarray(basecalled['Fastq']).item().decode().split()[1] 30 | kmers_list = seq_op.kmers2array(seq_op.seq2kmers(fastq,k = k)) 31 | # kmers_list = [0]*((k-1)//2) + kmers_list + [0]*((k-1)//2) 32 | kmers_list = [0]*(k-1) + kmers_list 33 | logit = 1/(1+1/np.exp(hidden)) 34 | 35 | ## Mask to display partail 36 | max_pos = 10 37 | hidden = hidden[position args.max[f_i]: 32 | print("Thresholding it to %d instances"%(args.max[f_i])) 33 | pieces = pieces[:args.max[f_i]] 34 | collections[key].append(pieces) 35 | shapes = [sum([len(y) for y in x]) for x in collections.values()] 36 | assert len(np.unique(shapes)) == 1 37 | 38 | for key in collections.keys(): 39 | print("merge %s"%(key)) 40 | if collections[key][0].ndim > 1: 41 | np.save(os.path.join(args.output,key),np.vstack(tuple(collections[key]))) 42 | else: 43 | np.save(os.path.join(args.output,key),np.hstack(tuple(collections[key]))) 44 | 45 | def main(args): 46 | merge(args) 47 | 48 | if __name__ == "__main__": 49 | parser = argparse.ArgumentParser( 50 | description='Training RHMM model') 51 | parser.add_argument("-i","--input", type = str, required = True, 52 | help = "Data folder contains npy file, separated by comma.") 53 | parser.add_argument("-o","--output",required = True, type = str, 54 | help = "The output folder to store the merged dataset.") 55 | parser.add_argument("-k","--key", type = str, default = "chunks,path,seqs,seq_lens,durations", 56 | help = "The name of npy items need to be collected, separated by comma.") 57 | parser.add_argument("-m","--max", default = None, 58 | help = "The maximum number of instances to be include in each dataset, can be a list of int separated by commas specify maximum number for each datasets.") 59 | args = parser.parse_args(sys.argv[1:]) 60 | os.makedirs(args.output,exist_ok = True) 61 | args.fs = args.input.strip().split(',') 62 | args.keys = args.key.strip().split(',') 63 | if args.max: 64 | if ',' in args.max: 65 | args.max = [int(x) for x in args.max.split(',')] 66 | else: 67 | args.max = [int(args.max)] * len(args.fs) 68 | assert len(args.fs) == len(args.max) 69 | else: 70 | args.max = None 71 | main(args) 72 | -------------------------------------------------------------------------------- /xron-samples/plot_segmentation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Spyder Editor 4 | 5 | This is a temporary script file. 
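It plots per-base signal segmentation for control and m6A IVT chunks, coloring each segment by the central base of its labelled kmer path.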
6 | """ 7 | import os 8 | import numpy as np 9 | from matplotlib import pyplot as plt 10 | from matplotlib import rcParams 11 | import toml 12 | rcParams['figure.dpi'] = 300 13 | plt.rcParams["font.family"] = "Arial" 14 | #make font bold 15 | plt.rcParams["font.weight"] = "bold" 16 | BRIDGE_F=os.environ['SCRATCH'] 17 | chunks_c = np.load(f"{BRIDGE_F}/ime4_Yearst/IVT/control/rep2/kmers_guppy_4000_noise/chunks.npy",mmap_mode = "r") 18 | paths_c = np.load(f"{BRIDGE_F}/ime4_Yearst/IVT/control/rep2/kmers_guppy_4000_noise/path.npy",mmap_mode = "r") 19 | chunks_m = np.load(f"{BRIDGE_F}/ime4_Yearst/IVT/m6A/rep2/kmers_guppy_4000_noise/chunks.npy",mmap_mode = "r") 20 | paths_m = np.load(f"{BRIDGE_F}/ime4_Yearst/IVT/m6A/rep2/kmers_guppy_4000_noise/path.npy",mmap_mode = "r") 21 | config = toml.load(f"{BRIDGE_F}/ime4_Yearst/IVT/control/rep2/kmers_guppy_4000_noise/config.toml") 22 | idx2kmer = config['idx2kmer'] 23 | fig,axs = plt.subplots(nrows = 2, figsize = (20,6.6)) 24 | axs[0].axis('off') 25 | axs[1].axis('off') 26 | 27 | def plot_segmentation(chunks,paths,test_i,T,ax,control = False,m6A = False): 28 | path = paths[test_i][:T] 29 | chunk = chunks[test_i][:T] 30 | move = path[1:] != path[:-1] 31 | pos = np.where(move)[0] 32 | pos = np.append(pos,T) 33 | ax.plot(chunk,color = "#A0B1BA",lw = 0.7) 34 | pre_p = 0 35 | color_map = {"A":"#FFC61E","G":"#AF58BA","C":"#009ADE","T":"#FF1F59","M":"#F28522"} 36 | i = 0 37 | for p in pos: 38 | b = idx2kmer[path[p-1]][2] 39 | if control: 40 | if b == "M": 41 | b = 'A' 42 | elif m6A: 43 | if b == "A": 44 | b = 'M' 45 | ax.axvspan(pre_p+0.01, p-0.01, color=color_map[b], alpha=1.0,lw = 0) 46 | offset = 15 47 | if p-pre_pTP[-1] else TP 24 | FP = FP[::-1] if FP[0]>FP[-1] else FP 25 | TP = [0] + TP if TP[0] != 0 else TP 26 | TP = TP + [1] if TP[-1] != 1 else TP 27 | FP = [0] + FP if FP[0] != 0 else FP 28 | FP = FP + [1] if FP[-1] != 1 else FP 29 | FP = np.asarray(FP) 30 | TP = np.asarray(TP) 31 | return np.sum((TP[1:] + TP[:-1])*(FP[1:]-FP[:-1])/2) 32 | 33 | def posterior_decode(posterior, 34 | M_threshold:float = 0.5): 35 | """Decode the posterior probability to get the modified ratio""" 36 | called = posterior > M_threshold 37 | return called.sum()/called.size 38 | 39 | def run(args): 40 | control_fast5 = args.control 41 | modified_fast5 = args.positive 42 | output = args.output 43 | label = "" if args.label is None else "_"+args.label 44 | n_total = np.inf if args.max_n == -1 else args.max_n 45 | TP,FP = [],[] 46 | c_p,m_p = [],[] 47 | for read_h,signal,abs_path,read_id in fast5_iter(control_fast5,mode = 'r',tqdm_bar = True): 48 | try: 49 | p = read_entry(read_h,entry = "ModifiedProbability",index = "001") 50 | c_p.append(p) 51 | except: 52 | pass 53 | if len(c_p) >= n_total: 54 | break 55 | for read_h,signal,abs_path,read_id in fast5_iter(modified_fast5,mode = 'r',tqdm_bar = True): 56 | try: 57 | p = read_entry(read_h,entry = "ModifiedProbability",index = "001") 58 | m_p.append(p) 59 | except: 60 | pass 61 | if len(m_p) >= n_total: 62 | break 63 | c_p = np.hstack(c_p) 64 | m_p = np.hstack(m_p) 65 | for t in np.arange(0,1.0001,0.002): 66 | TP.append(posterior_decode(m_p,M_threshold = t)) 67 | FP.append(posterior_decode(c_p,M_threshold = t)) 68 | fig,axs = plt.subplots(figsize = (5,5)) 69 | auc_plot(TP,FP,axs = axs) 70 | fig.savefig(os.path.join(output,"roc%s.png"%(label))) 71 | np.save(os.path.join(output,"TP%s.npy"%(label)),TP) 72 | np.save(os.path.join(output,"FP%s.npy"%(label)),FP) 73 | 74 | if __name__ == "__main__": 75 | parser = argparse.ArgumentParser() 76 | 
parser.add_argument("-c","--control",type = str, required = True,help = "Control fast5 file") 77 | parser.add_argument("-p","--positive",type = str,required = True,help = "Positive fast5 file") 78 | parser.add_argument("-o","--output",type = str,required = True, help = "Output directory") 79 | parser.add_argument("-l","--label",type = str,help = "Label for the output file") 80 | parser.add_argument("-n","--max_n",type = int,default = -1,help = "Maximum number of reads to use") 81 | args = parser.parse_args() 82 | run(args) -------------------------------------------------------------------------------- /xron/xron_annotate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Wed Jul 20 11:04:29 2022 3 | 4 | @author: Haotian Teng 5 | """ 6 | import numpy as np 7 | from xron.utils.fastIO import read_fast5 8 | from xron.utils.seq_op import fast5_iter 9 | from itertools import groupby 10 | from tqdm import tqdm 11 | import argparse 12 | import sys 13 | 14 | def softmax(logits,axis = -1): 15 | return np.exp(logits)/np.sum(np.exp(logits),axis = axis,keepdims = True) 16 | 17 | def get_posterior(logits,seq,move,canonical_base = 'A',modified_base = 'M',base_list = ['A','C','G','T','M'],n_largest_p = 3): 18 | """Get the posterior probability of the canonical base and modified base""" 19 | pos = np.cumsum(move)-1 20 | posterior = [] 21 | for k,g in groupby(zip(logits,pos),lambda x:x[1]): 22 | g = np.asarray([x[0] for x in g]) 23 | if k >= len(seq): 24 | print("Warning, found sequence that is too short, probabily the result is from a overlay>0 basecall.") 25 | return None 26 | if seq[k] == canonical_base or seq[k] == modified_base: 27 | g = softmax(g) 28 | p_canonical = g[:,base_list.index(canonical_base)+1] 29 | p_canonical.sort() 30 | p_canonical = p_canonical[-n_largest_p:].sum() 31 | p_modified = g[:,base_list.index(modified_base)+1] 32 | p_modified.sort() 33 | p_modified = p_modified[-n_largest_p:].sum() 34 | posterior.append(p_modified/(p_canonical+p_modified)) 35 | return np.asarray(posterior) 36 | 37 | def write_modified_probability(args): 38 | fail_count = 0 39 | with tqdm() as t: 40 | for read_h,signal,abs_path,read_id in fast5_iter(args.fast5,mode = 'a'): 41 | t.postfix = "File: %s, Read: %s, failed: %d"%(abs_path,read_id,fail_count) 42 | try: 43 | logits,move,seq = read_fast5(read_h,index = args.basecall_entry) 44 | except: 45 | fail_count += 1 46 | continue 47 | mod_p = get_posterior(logits, 48 | seq, 49 | move, 50 | canonical_base = args.canonical_base, 51 | modified_base = args.modified_base, 52 | base_list = [x for x in args.alphabeta], 53 | n_largest_p = args.n_largest_p) 54 | if mod_p is None: 55 | fail_count += 1 56 | continue 57 | result_h = read_h['Analyses/Basecall_1D_%s/BaseCalled_template'%(args.basecall_entry)] 58 | if 'ModifiedProbability' in result_h: 59 | del result_h['ModifiedProbability'] 60 | result_h.create_dataset('ModifiedProbability',data = mod_p,dtype = "f") 61 | result_h['ModifiedProbability'].attrs['alphabet'] = args.alphabeta 62 | result_h['ModifiedProbability'].attrs['canonical_base'] = args.canonical_base 63 | result_h['ModifiedProbability'].attrs['modified_base'] = args.modified_base 64 | result_h['ModifiedProbability'].attrs['n_largest_p'] = args.n_largest_p 65 | t.update() 66 | 67 | if __name__ == "__main__": 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument('-i','--fast5',type = str,required = True,help = 'Path to the fast5 folder') 70 | parser.add_argument('--basecall_entry',type = 
str,default = '000',help = 'Basecall entry') 71 | parser.add_argument('--alphabeta',type = str,default = 'ACGTM',help = 'Alphabet of the basecall model. Default is ACGTM') 72 | parser.add_argument('--canonical_base',type = str,default = 'A',help = 'Canonical base. Default is A') 73 | parser.add_argument('--modified_base',type = str,default = 'M',help = 'Modified base. Default is M') 74 | parser.add_argument('--n_largest_p',type = int,default = 3,help = 'Number of largest probability to use. Default is 3') 75 | args = parser.parse_args(sys.argv[1:]) 76 | write_modified_probability(args) -------------------------------------------------------------------------------- /xron/entry.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The Xron Authors. All Rights Reserved. 2 | # 3 | #This Source Code Form is subject to the terms of the Mozilla Public 4 | #License, v. 2.0. If a copy of the MPL was not distributed with this 5 | #file, You can obtain one at http://mozilla.org/MPL/2.0/. 6 | # 7 | #Created on Mon Aug 14 18:38:18 2017 8 | import argparse 9 | import sys 10 | import logging 11 | from os import path 12 | import xron 13 | from xron import xron_eval 14 | from xron import xron_train_supervised 15 | from xron import xron_init 16 | from xron.utils import prepare_chunk 17 | from xron.nrhmm import hmm_relabel 18 | 19 | def check_init(init_function): 20 | def decorator(function): 21 | def wrapper(*args, **kwargs): 22 | init_function() 23 | result = function(*args, **kwargs) 24 | return result 25 | return wrapper 26 | return decorator 27 | 28 | def check_model(): 29 | model_path = path.join(xron.__path__[0], 'models') 30 | if not path.exists(model_path): 31 | logging.error('Models not found. Please run "xron init" first.') 32 | sys.exit(1) 33 | 34 | @check_init(check_model) 35 | def evaluation(args): 36 | xron_eval.post_args(args) 37 | xron_eval.main(args) 38 | 39 | @check_init(check_model) 40 | def export(args): 41 | prepare_chunk.post_args(args) 42 | prepare_chunk.extract(args) 43 | 44 | def train(args): 45 | xron_train_supervised.post_args(args) 46 | xron_train_supervised.main(args) 47 | 48 | @check_init(check_model) 49 | def relabel(args): 50 | hmm_relabel.post_args(args) 51 | hmm_relabel.main(args) 52 | 53 | def main(arguments=sys.argv[1:]): 54 | parser = argparse.ArgumentParser(prog='xron', description='A deep neural network basecaller that achieve methylation.') 55 | parser.add_argument('-v','--version',action='version',version='Xron version '+xron.__version__,help="Print out the version.") 56 | subparsers = parser.add_subparsers(title='sub command', help='sub command help') 57 | 58 | # parser for 'init' command 59 | parser_init = subparsers.add_parser('init', description='Initialize xron package, need to run this when first time runnning xron', 60 | help='Initialize xron package, need to run this when first time runnning xron') 61 | parser_init.set_defaults(func=xron_init.get_models) 62 | 63 | # parser for 'call' command 64 | parser_call = subparsers.add_parser('call', description='Perform basecalling', help='Perform basecalling.') 65 | xron_eval.add_arguments(parser_call) 66 | parser_call.set_defaults(func=evaluation) 67 | 68 | # parser for 'extract' command 69 | parser_export = subparsers.add_parser('prepare', description='Prepare the training dataset by aligning it to the reference genome, it is an equivalent command to resquiggle when --extract_seq flag is set.', 70 | help='Realign the sequence to the reference genome and extract 
the signal chunk and label for training.') 71 | prepare_chunk.add_arguments(parser_export) 72 | parser_export.set_defaults(func=export) 73 | 74 | # parser for 'relabel' command 75 | parser_relabel = subparsers.add_parser('relabel', description='Relabel the training dataset using the pretraiend NHMM model.', 76 | help='Relabel the training dataset using the pretraiend NHMM model.') 77 | hmm_relabel.add_arguments(parser_relabel) 78 | parser_relabel.set_defaults(func=relabel) 79 | 80 | # parser for 'train' command 81 | parser_train = subparsers.add_parser('train', description='Training a model in several ways: embedding, supervised and reinforce', help='Train a model.') 82 | xron_train_supervised.add_arguments(parser_train) 83 | parser_train.set_defaults(func=train) 84 | 85 | args = parser.parse_args(arguments) 86 | if hasattr(args, 'func'): 87 | args.func(args) 88 | else: 89 | parser.print_help() 90 | 91 | 92 | if __name__ == '__main__': 93 | logging.basicConfig(level=logging.INFO) 94 | #print(sys.argv[1:]) 95 | main() -------------------------------------------------------------------------------- /xron/utils/index_multi.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Mon Oct 3 11:47:48 2022 5 | This script is slower than single thread index.py, probably because File IO is the bottleneck. 6 | use index.py instead. 7 | @author: heavens 8 | """ 9 | import os 10 | import h5py 11 | import argparse 12 | from xron.utils.seq_op import fast5_shallow_iter 13 | from xron.utils.fastIO import read_fastq 14 | import multiprocessing as mp 15 | from time import sleep 16 | from multiprocessing import Process 17 | from multiprocessing import Manager 18 | from multiprocessing import Queue 19 | from tqdm import tqdm 20 | import queue #In Python 2.7, this would be Queue 21 | 22 | DESIRED_TIMEOUT = 1 23 | MAX_result_dict_SIZE = 1000#This is the number of instances in the queue, the number should be set in the way that the max queue byte size smaller than the page capacity. 
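# Overall flow: the main process fills a job queue with fast5 file paths, the worker
# processes open each file and push (read_id, file_path) pairs onto the result queue
# for reads that appear in the fastq, and the main loop drains the result queue to
# write the index file while tracking per-file success/failure counts in the shared log.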
24 | def worker(job_queue,fastqs,result_queue,log): 25 | while True: 26 | try: 27 | i = job_queue.get(timeout = DESIRED_TIMEOUT) 28 | except queue.Empty: 29 | return 30 | #Do the job here 31 | try: 32 | with h5py.File(i,mode = 'r') as root: 33 | read_ids = list(root.keys()) 34 | for id in read_ids: 35 | if id[5:] in fastqs['name']: 36 | result_queue.put((id[5:],i)) 37 | log['success'] += 1 38 | except Exception as e: 39 | print("Reading %s failed due to %s."%(i,e)) 40 | log['fail'] += 1 41 | continue 42 | 43 | def run(args): 44 | manager = Manager() 45 | filequeue = manager.Queue() 46 | result_queue = manager.Queue() 47 | log = manager.dict() 48 | log['fail'] = 0 49 | log['success'] = 0 50 | 51 | max_threads_number = mp.cpu_count()-1 #1 thread is used for the main process 52 | all_proc = [] 53 | fastq_records = read_fastq(args.fastq) 54 | fast5_records = {} 55 | file_number = 0 56 | print("Read in fast5 file list.") 57 | for (dirpath, dirnames, filenames) in os.walk(args.fast5+'/'): 58 | for filename in filenames: 59 | if not filename.endswith('fast5'): 60 | continue 61 | abs_path = os.path.join(dirpath,filename) 62 | filequeue.put(abs_path) 63 | file_number += 1 64 | print("Create read id to fast5 file mapping.") 65 | for i in range(max_threads_number if args.threads is None else args.threads): 66 | p = Process(target = worker, args = (filequeue,fastq_records,result_queue,log)) 67 | all_proc.append(p) 68 | p.start() 69 | 70 | print("Indexing fastq files.") 71 | with open(args.fastq+'.index', 'w+') as f: 72 | with tqdm() as t: 73 | t.total = len(fastq_records['name']) 74 | while log['success'] + log['fail'] < file_number or not filequeue.empty() or not result_queue.empty(): 75 | try: 76 | #pop out result from the dict 77 | result = result_queue.get() 78 | if result[0] in fastq_records['name']: 79 | f.write(result[0]+'\t'+result[1]+'\n') 80 | fastq_records['name'].pop(fastq_records['name'].index(result[0])) 81 | t.update() 82 | except KeyError: 83 | sleep(0.1) 84 | continue 85 | if len(fastq_records['name']) != 0: 86 | raise ValueError('%d fastq readid not found in fast5'%(len(fastq_records['name']))) 87 | 88 | for p in all_proc: 89 | p.join() 90 | 91 | if __name__ == "__main__": 92 | parser = argparse.ArgumentParser(description='Process some integers.') 93 | parser.add_argument('--fast5', required = True, type=str, help='folder that contains fast5 output') 94 | parser.add_argument('--fastq', required = True, type=str, help='The merged fastq file') 95 | parser.add_argument('--threads', type=int, default = None, help='The number of threads used for indexing') 96 | args = parser.parse_args() 97 | run(args) -------------------------------------------------------------------------------- /xron/utils/decode.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from xron.utils.seq_op import raw2seq 4 | from torchaudio.functional import forced_align 5 | 6 | def beam_search(decoder, logits, seq_len, return_paths = True,vocab = ['b','A','C','G','T','M']): 7 | results = decoder(logits, seq_len) 8 | tokens = [ [x.tokens for x in result] for result in results] 9 | seqs = [[''.join([vocab[x] for x in t]) for t in token] for token in tokens] 10 | if return_paths: 11 | moves, path_logit = force_align_batch(logits, tokens,device = "cpu") 12 | return tokens, moves, path_logit 13 | return tokens 14 | 15 | 16 | def force_align_batch(log_probs, targets,device = 'cpu'): 17 | """A batch wrapper for force_align function 18 | force_align 
function can only take batch 1 as current version of torchaudio: 2.4.0 19 | log_probs: A tensor of shape [N,L,C] N - batch_size, L - sequence length, C - number of classes 20 | targets: A nested lists, where the first dimension is the batch size, and the second dimesnion 21 | is n_best paths, and the third dimension is the tokens 22 | 23 | """ 24 | moves, path_logits = [],[] 25 | N,L,C = log_probs.shape 26 | log_probs = log_probs.to(device) 27 | for i in range(N): 28 | log_prob = log_probs[i] 29 | target = targets[i] 30 | curr_m, curr_p = [],[] 31 | for t in target: 32 | t_len = torch.tensor([len(t)]).to(device) 33 | if not torch.is_tensor(t): 34 | t = torch.tensor(t) 35 | t = t.to(device) 36 | if len(t) == 0: 37 | align = torch.zeros(L,dtype = torch.bool) 38 | pl = torch.zeros(L,dtype = torch.float32) 39 | else: 40 | align,pl = forced_align(log_probs = log_prob.unsqueeze(0), 41 | targets = t.unsqueeze(0), 42 | target_lengths=t_len) 43 | move = (align>0) # [1,L] 44 | curr_p.append(pl[move]) 45 | #pad move to the same length as log_prob 46 | move = move.squeeze(0) # [L] 47 | move = torch.cat([move,torch.zeros(L-len(move),dtype = torch.bool)]) 48 | curr_m.append(move.to(int).numpy()) 49 | moves.append(curr_m) 50 | path_logits.append(curr_p) 51 | return moves, path_logits 52 | 53 | def viterbi_decode(logits:torch.tensor): 54 | """ 55 | Viterbi decdoing algorithm 56 | 57 | Parameters 58 | ---------- 59 | logits : torch.tensor 60 | Shape L-N-C 61 | 62 | Returns 63 | ------- 64 | sequence: A length N list contains final decoded sequence. 65 | moves: A length N list contains the moves array. 66 | 67 | """ 68 | sequence = np.argmax(logits,axis = 2) 69 | sequence = sequence.T #L,N -> N,L 70 | sequence,moves = raw2seq(sequence) 71 | return sequence,list(moves.astype(int)) 72 | 73 | if __name__ == "__main__": 74 | from time import time 75 | from fast_ctc_decode import fast_beam_search 76 | from torchaudio.models.decoder import cuda_ctc_decoder,ctc_decoder 77 | T = 1000 78 | logits = torch.randn(5, T, 6) # Example with batch size of 1, 100 time steps, and 6 classes 79 | logits[0,5:20,0] += 5 80 | 81 | # Convert logits to log probabilities 82 | log_probs = torch.nn.functional.log_softmax(logits, dim=-1) 83 | 84 | # Length of each sequence in the batch (assume all sequences are of max length for simplicity) 85 | seq_len = torch.tensor([T]*5, dtype=torch.int32) 86 | 87 | # Vocabulary (example) 88 | vocab = ['b','A','C','G','T','M'] # Example vocabulary 89 | 90 | # Initialize the CTC beam search decoder 91 | beam_search_decoder = cuda_ctc_decoder( 92 | tokens = vocab, 93 | nbest = 5, 94 | beam_size = 10 95 | ) 96 | beam_search_decoder_cpu = ctc_decoder( 97 | lexicon = None, 98 | tokens = vocab, 99 | nbest = 5, 100 | beam_size = 10, 101 | blank_token = 'b', 102 | sil_token = 'b' 103 | ) 104 | log_probs_cuda = log_probs.to('cuda') 105 | seq_len_cuda = seq_len.to('cuda') 106 | start = time.time() 107 | results_torch, tokens = beam_search(beam_search_decoder, log_probs_cuda, seq_len_cuda, return_paths = False) 108 | print("Elapsed time cuda beam search:",time.time()-start) 109 | start = time.time() 110 | results_ont,paths = fast_beam_search(logits) 111 | print("Elapsed time fast beam search:",time.time()-start) 112 | start = time.time() 113 | results_cpu, tokens_cpu = beam_search(beam_search_decoder_cpu, log_probs, seq_len, return_paths = False) 114 | print("Elapsed time cpu beam search:",time.time()-start) -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | ![xron_logo](https://github.com/haotianteng/Xron/blob/master/docs/images/xron_logo.png) 2 | Xron (ˈkairɑn) is a methylation basecaller that can identify m6A methylation modifications from ONT direct RNA sequencing. 3 | It uses a deep-learning CNN+RNN+CTC structure to perform end-to-end basecalling for the nanopore sequencer. 4 | The name is inherited from [Chiron](https://github.com/haotianteng/Chiron). 5 | Built with **PyTorch** and Python 3.8+. 6 | 7 | If you find Xron useful, please consider citing: 8 | > Teng, H., Stoiber, M., Bar-Joseph, Z. and Kingsford, C., 2024. [Detecting m6A RNA modification from nanopore sequencing using a semi-supervised learning framework.](https://www.biorxiv.org/content/10.1101/2024.01.06.574484v1.full.pdf) bioRxiv (*Genome Research*, in press). 9 | 10 | If you encounter any issues while using Xron, please submit an issue in the repository. 11 | 12 | m6A-aware RNA basecall one-liner: 13 | ``` 14 | xron call -i ${INPUT_FAST5} -o ${OUTPUT} -m models/ENEYFT --boostnano 15 | ``` 16 | 17 | A basecaller for the SQK-RNA004 nanopore kit is now provided! To use it: 18 | ``` 19 | xron call -i ${INPUT} -o ${OUTPUT} -m models/RNA004 20 | ``` 21 | 22 | ### High accuracy on direct-RNA004 sequencing kit 23 | ![RNA004](https://github.com/user-attachments/assets/5179540a-afed-4c62-86cf-42c831243e0c) 24 | 25 | ### Asynchronous m6A modifications in non-coding regions identified using RNA004 data 26 | ![modification_status](https://github.com/user-attachments/assets/4178ecfa-9597-40bf-b5d7-b0bab72974f1) 27 | 28 | 29 | --- 30 | ## Table of contents 31 | 32 | - [Table of contents](#table-of-contents) 33 | - [Install](#install) 34 | - [Install from Source](#install-from-source) 35 | - [Install from Pypi](#install-from-pypi) 36 | - [Basecall](#basecall) 37 | - [Segmentation using NHMM](#segmentation-using-nhmm) 38 | - [Prepare chunk dataset](#prepare-chunk-dataset) 39 | - [Realign the signal using NHMM.](#realign-the-signal-using-nhmm) 40 | - [Training](#training) 41 | 42 | ## Install 43 | For either installation method, we recommend creating a virtual environment first using conda or venv (conda is used in the example below). There is a known compilation issue when installing with Python > 3.8, so please install with Python 3.8. 44 | ```bash 45 | conda create --name YOUR_VIRTUAL_ENVIRONMENT python=3.8 46 | conda activate YOUR_VIRTUAL_ENVIRONMENT 47 | ``` 48 | Then you can install from our PyPI repository or install the newest version from the GitHub repository. 49 | 50 | ### Install from Pypi 51 | ```bash 52 | pip install xron 53 | ``` 54 | Xron requires at least PyTorch 1.11.0 to be installed. If you have not yet installed PyTorch, install it following the guide from the [official repository](https://pytorch.org/get-started/locally/). 55 | ## Basecall 56 | Before running basecalling with Xron, you need to download the models from our AWS S3 bucket by running **xron init**: 57 | ```bash 58 | xron init 59 | ``` 60 | This will automatically download the models and put them into the *models* folder. 61 | We provide sample code in the xron-samples folder for m6A-aware basecalling and m6A site identification. 62 | To run xron on raw fast5 files: 63 | ``` 64 | xron call -i ${INPUT_FAST5} -o ${OUTPUT} -m models/ENEYFT --fast5 --beam 50 --chunk_len 2000 65 | ``` 66 | 67 | ## Segmentation using NHMM 68 | ### Prepare chunk dataset 69 | Xron also includes a non-homogeneous HMM (NHMM) for signal resquiggling.
To use it: 70 | First, we need to extract the chunks and the basecalled sequences using the **prepare** module: 71 | ```bash 72 | xron prepare -i ${FAST5_FOLDER} -o ${CHUNK_FOLDER} --extract_seq --basecaller guppy --reference ${REFERENCE} --mode rna_meth --extract_kmer -k 5 --chunk_len 4000 --write_correction 73 | ``` 74 | Replace FAST5_FOLDER, CHUNK_FOLDER and REFERENCE with your basecalled fast5 folder, your output folder and the path to the reference genome fasta file. 75 | 76 | ### Realign the signal using NHMM. 77 | Then run the NHMM to realign ("resquiggle") the signal. 78 | ```bash 79 | xron relabel -i ${CHUNK_FOLDER} -m ${MODEL} --device $DEVICE 80 | ``` 81 | This will generate a path file (path.npy) under CHUNK_FOLDER which gives the kmer segmentation of the chunks. 82 | 83 | ## Training 84 | To train a new Xron model on your own data, prepare a training dataset that includes a signal file (chunks.npy), labelled sequences (seqs.npy) and the sequence length of each read (seq_lens.npy), and then run the Xron supervised training module: 85 | ```bash 86 | xron train -i chunks.npy --seq seqs.npy --seq_len seq_lens.npy --model_folder OUTPUT_MODEL_FOLDER 87 | ``` 88 | Training a Xron model from scratch is hard, so I would recommend fine-tuning our model by specifying the --load flag. For example, we can fine-tune the provided ENEYFT model (a model trained on the cross-linked ENE dataset and fine-tuned on the Yeast dataset): 89 | ```bash 90 | xron train -i chunks.npy --seq seqs.npy --seq_len seq_lens.npy --model_folder models/ENEYFT --load 91 | ``` 92 | 93 | -------------------------------------------------------------------------------- /xron/nrhmm/tandem_repeat_resquiggle.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script is used to resquiggle the reads that have a different tandem repeat number from the reference.
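Summary of the workflow implemented below: the script reads seqs.npy, path.npy, durations.npy and config.toml from --input_folder, fixes "looping" decoded paths with xron.nrhmm.kmer2seq.fixing_looping_path (sequences shorter than --min_seq_len are skipped), and writes the corrected paths back to path_fix.npy in the same folder.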
3 | """ 4 | import os 5 | import toml 6 | import argparse 7 | import numpy as np 8 | from functools import partial 9 | from tqdm import tqdm 10 | from multiprocessing import Pool 11 | from multiprocessing import Queue 12 | from xron.nrhmm.kmer2seq import fixing_looping_path 13 | import multiprocessing as mp 14 | from time import sleep 15 | from multiprocessing import Process 16 | from multiprocessing import Manager 17 | from multiprocessing import Queue 18 | from tqdm import tqdm 19 | import queue #In Python 2.7, this would be Queue 20 | 21 | DESIRED_TIMEOUT = 1 22 | MAX_RESULT_QUEUE_SIZE = 1000 23 | SLEEPING_TIME = 0.01 24 | def worker(job_queue,args,result_queue): 25 | while True: 26 | try: 27 | i,seq,path,duration = job_queue.get(timeout = DESIRED_TIMEOUT) 28 | except queue.Empty: 29 | return 30 | #Do the job here 31 | if len(seq) < args.min_seq_len: 32 | result_queue.put({i:(None,None)}) 33 | else: 34 | fixed_path, fixing = fixing_looping_path_func(path[:duration],seq) 35 | # while result_queue.qsize() > MAX_RESULT_QUEUE_SIZE: 36 | # sleep(SLEEPING_TIME) 37 | result_queue.put({i:(fixed_path,fixing)}) 38 | 39 | def main(args): 40 | seqs,path,durations,config = load_data(args.input_folder) 41 | manager = Manager() 42 | filequeue = manager.Queue() 43 | result_queue = manager.Queue(maxsize = MAX_RESULT_QUEUE_SIZE) 44 | results_dict = {"Fixed":0, "No need to fix":0, "Fix failed":0} 45 | max_threads_number = mp.cpu_count()-1 if args.threads is None else args.threads #1 thread is used for the main process 46 | all_proc = [] 47 | file_number = len(path) 48 | for i in range(file_number): 49 | filequeue.put((i,seqs[i],path[i],durations[i])) 50 | for i in range(max_threads_number): 51 | p = Process(target = worker, args = (filequeue,args,result_queue)) 52 | all_proc.append(p) 53 | p.start() 54 | to_finished = list(np.arange(file_number)) 55 | with tqdm() as t: 56 | t.total = file_number 57 | t.set_description("Fixing the path of control data.") 58 | while len(to_finished): 59 | try: 60 | result = result_queue.get(timeout = DESIRED_TIMEOUT) 61 | i = list(result.keys())[0] 62 | fixed_path, fixing = result[i] 63 | if fixing is None: 64 | results_dict["Fix failed"] += 1 65 | elif fixing: 66 | results_dict['Fixed'] += 1 67 | path[i][:durations[i]] = fixed_path 68 | else: 69 | results_dict['No need to fix'] += 1 70 | elements_in_queue = result_queue.qsize() 71 | t.set_description(f"Fixed {results_dict['Fixed']}, No need to fix {results_dict['No need to fix']}, Skip because sequence is too short {results_dict['Fix failed']}, elements in queue {elements_in_queue}") 72 | if i in to_finished: 73 | to_finished.remove(i) 74 | else: 75 | raise ValueError(f"The index {i} is not in the to_finished list.") 76 | t.update() 77 | except queue.Empty: 78 | continue 79 | except Exception as e: 80 | print(e) 81 | break 82 | print("Saving the delooped result.") 83 | np.save(os.path.join(args.input_folder,"path_fix.npy"),path) 84 | print("Finished") 85 | for i,p in enumerate(all_proc): 86 | p.join() 87 | 88 | def load_data(input_folder): 89 | seqs = np.load(os.path.join(input_folder,"seqs.npy")) 90 | path = np.load(os.path.join(input_folder,"path.npy")) 91 | duration = np.load(os.path.join(input_folder,"durations.npy")) 92 | config = toml.load(os.path.join(input_folder,"config.toml")) 93 | return seqs,path,duration,config 94 | 95 | if __name__ == "__main__": 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument("-i","--input_folder",help = "The folder that contains the data to be fixed.") 98 | 
parser.add_argument("--modified_base",default = "M",help = "The modified base.") 99 | parser.add_argument("--canonical_base",default = "A",help = "The canonical base.") 100 | parser.add_argument("--threads",default = None,type = int,help = "The number of threads to use, default is None which means using all the threads.") 101 | parser.add_argument("--min_seq_len",default = 7,type = int,help = "The minimum length of the sequence to be fixed.") 102 | args = parser.parse_args() 103 | config = toml.load(os.path.join(args.input_folder,"config.toml")) 104 | fixing_looping_path_func = partial(fixing_looping_path,idx2kmer = config['idx2kmer'],modified_base=args.modified_base,canonical_base=args.canonical_base) 105 | main(args) -------------------------------------------------------------------------------- /xron/nrhmm/profile.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Aug 9 07:09:45 2022 5 | 6 | @author: heavens 7 | """ 8 | import os 9 | import csv 10 | import toml 11 | import itertools 12 | from tqdm import tqdm 13 | import numpy as np 14 | 15 | 16 | class Profiler(object): 17 | def __init__(self,data_f, config): 18 | self.data_f = data_f 19 | print("Loading chunk data") 20 | chunks_f = os.path.join(data_f,"chunks.npy") 21 | self.chunks = np.load(chunks_f) 22 | self.chunk_flatten = self.chunks.flatten() 23 | print("Loading decoded path data") 24 | decoded_f = os.path.join(data_f,"path.npy") 25 | self.decoded = np.load(decoded_f) 26 | self.decoded_flatten = self.decoded.flatten() 27 | self.config = config 28 | self.idx2kmer = config['idx2kmer'] 29 | self.kmer2idx = config['kmer2idx_dict'] 30 | self.n_kmer = len(self.idx2kmer) 31 | self.alphabeta = config['alphabeta'] 32 | self.n_base = len(self.alphabeta) 33 | self.transitions_count = [[[]for _ in np.arange(self.n_base)] for _ in np.arange(self.n_kmer)] 34 | self.transition_weight = np.zeros((self.n_kmer,self.n_base)) 35 | self.epsilon = 1e-6 36 | assert np.equal(self.chunks.shape,self.decoded.shape).all() 37 | group_f = os.path.join(self.data_f,"grouped.npy") 38 | if os.path.isfile(group_f): 39 | print("Loading grouped array.") 40 | self.kmer_groups = np.load(group_f,allow_pickle = True) 41 | else: 42 | self.kmer_groups = None 43 | 44 | def build_invariant_kmers(self, variant_bases): 45 | self.invariant_kmers = [] 46 | for i in np.arange(self.n_kmer): 47 | if not any([x in self.idx2kmer[i] for x in variant_bases]): 48 | self.invariant_kmers.append(i) 49 | 50 | def grouping(self): 51 | print("Sorting decoded path") 52 | argsort = np.argsort(self.decoded_flatten) 53 | sorted_chunks = self.chunk_flatten[argsort] 54 | print("Grouping signal") 55 | kmers,idxs = np.unique(self.decoded_flatten[argsort],return_index = True) 56 | grouped = np.split(sorted_chunks,idxs[1:]) 57 | self.kmer_groups = [[] for _ in np.arange(self.n_kmer)] 58 | for i,kmer in enumerate(kmers): 59 | self.kmer_groups[kmer] = grouped[i] 60 | print("Writing summary") 61 | np.save(os.path.join(self.data_f,"grouped.npy"),self.kmer_groups) 62 | 63 | def summarize(self): 64 | if self.kmer_groups is None: 65 | raise ValueError("Grouping file has not been found, please run grouping first.") 66 | self.means = np.asarray([np.mean(self.kmer_groups[i]) for i in tqdm(np.arange(self.n_kmer))]) 67 | self.stds = np.asarray([np.std(self.kmer_groups[i]) for i in tqdm(np.arange(self.n_kmer))]) 68 | 69 | def summarize_length(self): 70 | flatten_path = self.decoded.flatten() 71 | 
self.kmer_length = {} 72 | for i,kmer in tqdm(enumerate(self.idx2kmer)): 73 | condition = flatten_path == i 74 | self.kmer_length[kmer] = np.diff(np.where(np.concatenate(([condition[0]], 75 | condition[:-1] != condition[1:], 76 | [True])))[0])[::2] 77 | 78 | def summarize_transition(self): 79 | for path in self.decoded: 80 | keys,counts = [],[] 81 | for key,group in itertools.groupby(path): 82 | keys.append(key) 83 | counts.append(len(list(group))) 84 | for i,k in enumerate(keys): 85 | if i == len(keys) - 1: 86 | break 87 | next_base = self.idx2kmer[keys[i+1]][-1] 88 | self.transitions_count[k][self.alphabeta.index(next_base)].append(counts[i]) 89 | for i in np.arange(self.n_kmer): 90 | for j in np.arange(self.n_base): 91 | c = self.transitions_count[i][j] 92 | self.transition_weight[i][j] = len(c)/(sum(c)+self.epsilon) #ML estimation of transition probability is N/sum_i k_i - N respect that the stay probability is 1, see the notebook for more details. 93 | 94 | if __name__ == "__main__": 95 | home_f = os.path.expanduser("~") 96 | control_folder = home_f + "/bridge_scratch/ime4_Yearst/IVT/control/rep1/kmers_guppy_4000_noise" 97 | m6a_folder = home_f + "/bridge_scratch/ime4_Yearst/IVT/m6A/rep2/kmers_guppy_4000_noise" 98 | config = toml.load(os.path.join(control_folder,"config.toml")) 99 | p_control= Profiler(control_folder,config) 100 | # p_control.summarize() 101 | # p_control.summarize_length() 102 | p_control.summarize_transition() 103 | p_m6A= Profiler(m6a_folder,config) 104 | # p_m6A.summarize() 105 | # p_m6A.summarize_length() 106 | p_m6A.summarize_transition() 107 | 108 | -------------------------------------------------------------------------------- /xron/xron_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Sun Dec 26 19:38:25 2021 3 | 4 | @author: Haotian Teng 5 | """ 6 | import os 7 | import sys 8 | import torch 9 | import argparse 10 | from functools import partial 11 | from xron.xron_model import CONFIG,DECODER_CONFIG,CRITIC_CONFIG,MM_CONFIG 12 | from xron.xron_train_supervised import main as supervised_train 13 | from xron.xron_train_variational import main as reinforce_train 14 | from xron.xron_train_embedding import main as embedding_train 15 | 16 | optimizers = {'Adam':torch.optim.Adam, 17 | 'AdamW':torch.optim.AdamW, 18 | 'SGD':torch.optim.SGD, 19 | 'RMSprop':torch.optim.RMSprop, 20 | 'Adagrad':torch.optim.Adagrad, 21 | 'Momentum':partial(torch.optim.SGD,momentum = 0.9)} 22 | 23 | def main(args): 24 | class CTC_CONFIG(MM_CONFIG): 25 | CTC = {"beam_size":5, 26 | "beam_cut_threshold":0.05, 27 | "alphabeta": "ACGTM", 28 | "mode":"rna"} 29 | 30 | class TRAIN_EMBEDDING_CONFIG(CTC_CONFIG): 31 | TRAIN = {"inital_learning_rate":args.lr, 32 | "batch_size":args.batch_size, 33 | "grad_norm":2, 34 | "epsilon":0.1, 35 | "epsilon_decay":0, 36 | "alpha":0.01, #Entropy loss scale factor 37 | "keep_record":5, 38 | "decay":args.decay, 39 | "diff_signal":args.diff} 40 | 41 | class TRAIN_SUPERVISED_CONFIG(CTC_CONFIG): 42 | TRAIN = {"inital_learning_rate":args.lr, 43 | "batch_size":args.batch_size, 44 | "grad_norm":2, 45 | "keep_record":5, 46 | "eval_size":10000, 47 | "optimizer":optimizers[args.optimizer]} 48 | 49 | class TRAIN_REINFORCE_CONFIG(CTC_CONFIG): 50 | TRAIN = {"inital_learning_rate":args.lr, 51 | "batch_size":args.batch_size, 52 | "grad_norm":2, 53 | "epsilon":0.1, 54 | "epsilon_decay":0, 55 | "alpha":0.01, #Entropy loss scale factor 56 | "beta": 1., #Reconstruction loss scale factor 57 | "gamma":0, #Alignment loss 
scale factor 58 | "preheat":5000, 59 | "keep_record":5, 60 | "decay":args.decay, 61 | "diff_signal":args.diff} 62 | train_config = {"Embedding":TRAIN_EMBEDDING_CONFIG, 63 | "Supervised":TRAIN_SUPERVISED_CONFIG, 64 | "Reinforce":TRAIN_REINFORCE_CONFIG} 65 | train_module = {"Embedding":embedding_train, 66 | "Supervised":supervised_train, 67 | "Reinforce":reinforce_train} 68 | args.config = train_config[args.module] 69 | train_module[args.module](args) 70 | 71 | def add_arguments(parser): 72 | parser.add_argument('--module', required = True, 73 | help = "The training module to call, can be Embedding, Supervised and Reinforce") 74 | parser.add_argument('-i', '--chunks', required = True, 75 | help = "The .npy file contain chunks.") 76 | parser.add_argument('-o', '--model_folder', required = True, 77 | help = "The folder to save folder at.") 78 | parser.add_argument('--seq', required = True, 79 | help="The .npy file contain the sequence.") 80 | parser.add_argument('--seq_len', required = True, 81 | help="The .npy file contain the sueqnece length.") 82 | parser.add_argument('--device', default = 'cuda', 83 | help="The device used for training, can be cpu or cuda.") 84 | parser.add_argument('--lr', default = 4e-3, type = float, 85 | help="Initial learning rate.") 86 | parser.add_argument('--batch_size', default = 200, type = int, 87 | help="Training batch size.") 88 | parser.add_argument('--epoches', default = 10, type = int, 89 | help = "The number of epoches to train.") 90 | parser.add_argument('--report', default = 20, type = int, 91 | help = "The interval of training rounds to report.") 92 | parser.add_argument('--load', dest='retrain', action='store_true', 93 | help='Load existed model.') 94 | parser.add_argument('--config', default = None, 95 | help = "Training configuration.") 96 | parser.add_argument('--optimizer', default = "RMSprop", 97 | help = "Optimizer to use, can be Adam, AdamW, SGD and RMSprop,\ 98 | default is RMSprop") 99 | parser.add_argument('--threads', type = int, default = None, 100 | help = "Number of threads used by Pytorch") 101 | 102 | def post_args(args): 103 | os.makedirs(args.model_folder,exist_ok=True) 104 | 105 | if __name__ == "__main__": 106 | parser = argparse.ArgumentParser( 107 | description='Calling training module.') 108 | add_arguments(parser) 109 | args = parser.parse_args(sys.argv[1:]) 110 | main(args) -------------------------------------------------------------------------------- /xron/xron_index_shelve.py: -------------------------------------------------------------------------------- 1 | #Build an index dictionary for sequence position -> reference position & signal position 2 | import re 3 | import pysam 4 | import pod5 5 | import pysam 6 | import pickle 7 | import shelve 8 | import zlib 9 | from collections import defaultdict 10 | from tqdm import tqdm 11 | import numpy as np 12 | DATA_TYPE = np.int32 13 | 14 | def get_hc(read): 15 | cigar_string = read.cigarstring 16 | # Find all occurrences of a number followed by 'H' at the start or end 17 | start_match = re.match(r'^(\d+)H', cigar_string) 18 | end_match = re.search(r'.(\d+)H$', cigar_string) 19 | 20 | start_h = int(start_match.group(1)) if start_match else 0 21 | end_h = int(end_match.group(1)) if end_match else 0 22 | 23 | return start_h, end_h 24 | 25 | def find_bases(move_table): 26 | """Find an index mapping for the sequence position to the signal position 27 | from the move table 28 | Args: 29 | move_table: np.array, the move table from the mv:b:c tag, e.g. [1,0,0,1,1,0,0,1,...] 
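For example (a small worked illustration of the cumsum/diff logic below): a move_table of [1,0,0,1,1,0] gives seq_idx [0,1,2] and sig_pos [(0,3),(3,4),(4,6)].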
30 | Returns: 31 | seq_idx: np.array, the sequence index 32 | sig_pos: list of tuple, the start and end of the signal position of the same sequence index 33 | """ 34 | seq_idx = np.cumsum(move_table,dtype = DATA_TYPE) 35 | changes = np.diff(seq_idx) 36 | changes = np.insert(changes, 0, 1) # Insert a change at the beginning 37 | indices = np.where(changes != 0)[0] 38 | 39 | starts = indices 40 | ends = np.append(indices[1:], len(move_table)) 41 | 42 | sig_pos = list(zip(starts, ends)) 43 | return np.unique(seq_idx)-1,sig_pos 44 | 45 | class Indexer(object): 46 | def __init__(self,compress = True,use_shelve = False): 47 | self.run_stat = defaultdict(int) 48 | self.compress = compress 49 | self.use_shelve = use_shelve 50 | 51 | def build(self, 52 | bam_f): 53 | sam_f = pysam.AlignmentFile(bam_f, "rb") 54 | out_f = bam_f + ".index" 55 | total_count = 0 56 | if self.use_shelve: 57 | index = shelve.open(out_f,'c') 58 | else: 59 | index = {} 60 | with tqdm(sam_f, mininterval=0.1) as t: 61 | for read in t: 62 | if read.is_unmapped: 63 | self.run_stat['Unaligned'] += 1 64 | continue 65 | if read.is_secondary: 66 | self.run_stat['Secondary alignment'] += 1 67 | continue 68 | ref_contig = read.reference_name 69 | qr_map = read.get_aligned_pairs(matches_only= False, with_seq = False) 70 | hc_start,hc_end = get_hc(read) 71 | read_id = read.query_name 72 | try: 73 | trim_length = read.get_tag("ts:i") 74 | move_table = read.get_tag("mv:B") 75 | except KeyError: 76 | self.run_stat['No moving table'] += 1 77 | if self.run_stat['No moving table']//(self.run_stat['Success']+1) > 100: 78 | raise ValueError("No moving table in the basecalled bam file, is --emit-moves tag used during basecall?") 79 | continue 80 | #get signal position by accumulating the move_table 81 | stride = move_table[0] 82 | seq_idx, sig_pos = find_bases(move_table[1:]) 83 | curr = {} 84 | total_count += len(seq_idx)*3 + len(qr_map)*2 85 | curr['seq2sig'] = {i:s for i,s in zip(seq_idx,sig_pos)} 86 | curr['sig2seq'] = [s[0]*stride+trim_length for s in sig_pos] 87 | curr['stride'] = stride 88 | curr['trim_length'] = trim_length 89 | curr['referece_name'] = ref_contig 90 | if read.is_reverse: 91 | hc_start,hc_end = hc_end,hc_start 92 | q_transform = lambda x: len(read.seq)-x-1 if x is not None else None 93 | else: 94 | q_transform = lambda x: x 95 | curr['query_reference_map'] = {q_transform(x[0]):x[1] for x in qr_map} 96 | curr['hard_clip'] = (hc_start,hc_end) 97 | if self.compress: 98 | curr = zlib.compress(pickle.dumps(curr)) 99 | index[read_id] = curr 100 | self.run_stat['Success'] += 1 101 | 102 | # early stop for testing 103 | # if self.run_stat['Success'] >= 10000: 104 | # break 105 | 106 | # Update run status 107 | t.set_postfix(self.run_stat,refresh = False) 108 | print(f"Estimaed size: {total_count*4/1024/1024} MB") 109 | if self.use_shelve: 110 | index.close() 111 | else: 112 | pickle.dump(index,open(out_f,"wb")) 113 | 114 | def args(): 115 | import argparse 116 | parser = argparse.ArgumentParser() 117 | parser.add_argument("--bam",type = str,required = True,help = "The aligned bam file") 118 | parser.add_argument("--compress",action = "store_true",dest = "compress", 119 | help = "Compress the index entry.") 120 | parser.add_argument("--no-shelve",action = "store_false", dest = "use_shelve", 121 | help = "Enable shelve database.") 122 | return parser.parse_args() 123 | 124 | if __name__ == "__main__": 125 | import time 126 | args = args() 127 | start = time.time() 128 | input_bam = args.bam 129 | runner = Indexer(compress = 
args.compress,use_shelve = args.use_shelve) 130 | runner.build(input_bam) 131 | print(f"Time elapsed: {time.time()-start}") 132 | -------------------------------------------------------------------------------- /xron/nrhmm/transition_speed_test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Thu Apr 28 18:03:07 2022 3 | 4 | @author: Haotian Teng 5 | """ 6 | import torch 7 | import itertools 8 | import numpy as np 9 | import timeit 10 | 11 | def sparse_multiplication(transition,log_alpha): 12 | idxs = transition.indices() 13 | vals = transition.values().log() 14 | update_check = torch.zeros((batch_size,n_kmers),dtype = torch.bool) 15 | update = log_alpha[idxs[0],idxs[1]] +vals 16 | for i,u in enumerate(update): 17 | batch_i,kmer_i = idxs[0][i],idxs[2][i] 18 | if not update_check[batch_i,kmer_i]: 19 | curr_update = update[torch.logical_and(idxs[0]==batch_i,idxs[2] == kmer_i)] 20 | log_alpha[idxs[0][i],idxs[2][i]] = torch.logsumexp(curr_update,dim = 0) 21 | update_check[batch_i,kmer_i] = True 22 | log_alpha[torch.logical_not(update_check)] += np.log(1e-6) 23 | return log_alpha 24 | 25 | def gpu_multiplication(transition,log_alpha): 26 | dense_t = transition.to_dense() 27 | return log_domain_matmul(log_alpha, dense_t.log()) 28 | 29 | def base_multiplication(transition_base,indexing,log_alpha): 30 | n_kmer,n_base = indexing.shape 31 | alpha_shape = log_alpha.shape 32 | multiplication = log_alpha[:,indexing] 33 | return torch.logsumexp(transition_base.log() + multiplication,dim = -1) 34 | 35 | def log_domain_sparse_matmul(A:torch.sparse_coo_tensor, log_B:torch.Tensor, dim:int = -1): 36 | """ 37 | Do a sparse-dense tensor multiplication and reduced on the given dimension. 38 | 39 | Parameters 40 | ---------- 41 | A : torch.sparse_coo_tensor 42 | A sparse tensor with shape mxnxp. 43 | log_B : torch.Tensor 44 | A dense tensor in the log domain with same shape, broadcast is supported on dense tensor. 45 | dim : int, optional 46 | The dimension to perform reduction on A. The default is -1. 47 | 48 | Returns 49 | ------- 50 | A sparse tensor. 51 | 52 | """ 53 | A = A.coalesce() 54 | idxs = A.indices() 55 | log_vals = A.values().log() 56 | shape_A = torch.tensor(A.shape) 57 | shape_B = torch.tensor(log_B.shape) 58 | n_dims_A = idxs.shape[0] 59 | n_dims_B = len(shape_B) 60 | assert n_dims_A == n_dims_B, "Tensor has different number of dimensions." 61 | assert torch.all(shape_A >= shape_B), "Broadcast only supported on dense tensor." 
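    # Outline of the log-domain reduction implemented below:
    #   1. gather log_B at the sparse indices of A (dimensions where B has size 1 are broadcast by
    #      mapping their index to 0) and add the log of A's stored values;
    #   2. group the resulting terms by the index tuple of the non-reduced dimensions;
    #   3. pad every group with -inf to a common length and apply logsumexp over each group, which
    #      gives the values of the reduced sparse output tensor.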
62 | idxs_B = idxs.clone() 63 | remain_dims = np.arange(n_dims_A) 64 | remain_dims = np.delete(remain_dims,dim) 65 | remain_idxs = list(zip(*[idxs.cpu()[x].tolist() for x in remain_dims])) 66 | idxs_B[torch.where(shape_B==1)] = 0 67 | update = log_B[tuple(idxs_B)] + log_vals 68 | key_func = lambda x: x[1] 69 | update = update.tolist() 70 | update = sorted(zip(update,remain_idxs),key = key_func) 71 | nested = [ (k,list(g)) for k,g in itertools.groupby(update,key = key_func)] 72 | nested_vals = [[y[0] for y in x[1]] for x in nested] 73 | nested_idxs = [x[0] for x in nested] 74 | max_cols = max([len(x) for x in nested_vals]) 75 | padded = torch.tensor([x + [-np.inf]*(max_cols - len(x)) for x in nested_vals],device = A.device) 76 | return torch.sparse_coo_tensor(indices = list(zip(*nested_idxs)), 77 | values = torch.logsumexp(padded,dim = 1), 78 | size = shape_A[remain_dims].tolist(), 79 | device = A.device) 80 | 81 | 82 | def log_domain_matmul(log_A, log_B): 83 | """ 84 | log_A : m x n 85 | log_B : m x n x p or n x p 86 | output : m x p matrix 87 | 88 | Normally, a matrix multiplication 89 | computes out_{i,j} = sum_k A_{i,k} x B_{i,k,j} 90 | 91 | A log domain matrix multiplication 92 | computes out_{i,j} = logsumexp_k log_A_{i,k} + log_B_{i,k,j} 93 | """ 94 | dim_B = len(log_B.shape) 95 | log_A = log_A.unsqueeze(dim = 2) 96 | if dim_B == 2: 97 | log_B = log_B.unsqueeze(dim = 0) 98 | elementwise_sum = log_A + log_B 99 | out = torch.logsumexp(elementwise_sum, dim=1) 100 | return out 101 | 102 | if __name__ == "__main__": 103 | sparsity = 0.0014 104 | n_kmers = 3125 105 | batch_size = 2 106 | n_base = 5 107 | n_elements = int(n_kmers*n_kmers*sparsity) 108 | start = np.random.randint(low = 0, high = n_kmers - 1, size = n_elements) 109 | end = np.random.randint(low = 0, high = n_kmers - 1, size = n_elements) 110 | batch_i = np.random.randint(low = 0, high = batch_size - 1, size = n_elements) 111 | transition = torch.sparse_coo_tensor(indices = (batch_i,start,end),values = [1.]*len(start),size = (batch_size,n_kmers,n_kmers)).coalesce() 112 | transition_cuda = transition.to("cuda") 113 | log_alpha = torch.rand((batch_size,n_kmers)) 114 | log_alpha_cuda = log_alpha.to("cuda") 115 | result = log_domain_sparse_matmul(transition_cuda,log_alpha_cuda.unsqueeze(dim = 2),dim = 1) 116 | result_gpu = gpu_multiplication(transition_cuda,log_alpha_cuda) 117 | assert torch.all(torch.isclose(result.to_dense()[torch.logical_not(torch.isinf(result_gpu))], result_gpu[torch.logical_not(torch.isinf(result_gpu))])) 118 | # timeit.timeit(sparse_multiplication(transition, log_alpha),number = 10000) 119 | # timeit.timeit(gpu_multiplication(transition_cuda,log_alpha_cuda),number = 10000) 120 | transition_base = torch.rand((batch_size,n_kmers,n_base+1)).to("cuda") 121 | base_index = torch.randint(low = 0,high = n_kmers,size = (n_kmers,n_base + 1)).to("cuda") 122 | base_multiplication(transition_base,base_index,log_alpha_cuda) 123 | # timeit.timeit(base_multiplication(transition_base,base_index,log_alpha_cuda),number = 10000) -------------------------------------------------------------------------------- /xron/xron_index_lmdb.py: -------------------------------------------------------------------------------- 1 | # Build an index dictionary for sequence position -> reference position & signal position 2 | import re 3 | import pysam 4 | import pod5 5 | import pickle 6 | import zlib 7 | import lmdb 8 | from collections import defaultdict 9 | from tqdm import tqdm 10 | import numpy as np 11 | DATA_TYPE = np.int32 12 | 
START_H_PATTERN = re.compile(r'^(\d+)H') 13 | END_H_PATTERN = re.compile(r'.(\d+)H$') 14 | 15 | def get_hc(read): 16 | cigar_string = read.cigarstring 17 | # Find all occurrences of a number followed by 'H' at the start or end 18 | start_match = START_H_PATTERN.match(cigar_string) 19 | end_match = END_H_PATTERN.search(cigar_string) 20 | 21 | start_h = int(start_match.group(1)) if start_match else 0 22 | end_h = int(end_match.group(1)) if end_match else 0 23 | 24 | return start_h, end_h 25 | 26 | def find_bases(move_table): 27 | """Find an index mapping for the sequence position to the signal position 28 | from the move table 29 | Args: 30 | move_table: np.array, the move table from the mv:b:c tag, e.g. [1,0,0,1,1,0,0,1,...] 31 | Returns: 32 | seq_idx: np.array, the sequence index 33 | sig_pos: list of tuple, the start and end of the signal position of the same sequence index 34 | """ 35 | seq_idx = np.cumsum(move_table,dtype = DATA_TYPE) 36 | changes = np.diff(seq_idx, prepend=1) 37 | indices = np.where(changes != 0)[0] 38 | 39 | starts = indices 40 | ends = np.append(indices[1:], len(move_table)) 41 | 42 | sig_pos = list(zip(starts, ends)) 43 | return np.unique(seq_idx)-1, sig_pos 44 | 45 | class Indexer(dict): 46 | def __init__(self, compress=False): 47 | self.run_stat = defaultdict(int) 48 | self.compress = compress 49 | self.db = None 50 | 51 | def load(self,bam_f): 52 | index_f = bam_f + ".index" 53 | env = lmdb.open(index_f, readonly=True) 54 | self.db = env.begin() 55 | 56 | def __getitem__(self,key): 57 | if self.db is None: 58 | raise ValueError("No database has been loaded!") 59 | return pickle.loads(zlib.decompress(self.db[key])) 60 | 61 | def _write_to_lmdb(self, db,key,value): 62 | """ 63 | Write (key,value) to db 64 | """ 65 | success = False 66 | while not success: 67 | txn = db.begin(write=True) 68 | try: 69 | txn.put(key, value) 70 | txn.commit() 71 | success = True 72 | except lmdb.MapFullError: 73 | txn.abort() 74 | 75 | # double the map_size 76 | curr_limit = db.info()['map_size'] 77 | new_limit = curr_limit*2 78 | db.set_mapsize(new_limit) # double it 79 | 80 | def build(self, bam_f): 81 | sam_f = pysam.AlignmentFile(bam_f, "rb") 82 | out_f = bam_f + ".index" 83 | env = lmdb.open(out_f, map_size=int(1e11)) # Adjust map_size as needed 84 | with tqdm(sam_f, mininterval=0.1) as t: 85 | for read in t: 86 | if read.is_unmapped: 87 | self.run_stat['Unaligned'] += 1 88 | continue 89 | if read.is_secondary: 90 | self.run_stat['Secondary alignment'] += 1 91 | continue 92 | ref_contig = read.reference_name 93 | qr_map = read.get_aligned_pairs(matches_only=False, with_seq=False) 94 | hc_start, hc_end = get_hc(read) 95 | read_id = read.query_name 96 | try: 97 | trim_length = read.get_tag("ts:i") 98 | move_table = read.get_tag("mv:B") 99 | except KeyError: 100 | self.run_stat['No moving table'] += 1 101 | if self.run_stat['No moving table'] // (self.run_stat['Success'] + 1) > 100: 102 | raise ValueError("No moving table in the basecalled bam file, is --emit-moves tag used during basecall?") 103 | continue 104 | # Get signal position by accumulating the move_table 105 | stride = move_table[0] 106 | seq_idx, sig_pos = find_bases(move_table[1:]) 107 | curr = { 108 | 'seq2sig': {i: s for i, s in zip(seq_idx, sig_pos)}, 109 | 'sig2seq': [s[0] * stride + trim_length for s in sig_pos], 110 | 'stride': stride, 111 | 'trim_length': trim_length, 112 | 'reference_name': ref_contig 113 | } 114 | if read.is_reverse: 115 | hc_start, hc_end = hc_end, hc_start 116 | q_transform = lambda x: 
len(read.seq) - x - 1 if x is not None else None 117 | else: 118 | q_transform = lambda x: x 119 | curr['query_reference_map'] = {q_transform(x[0]): x[1] for x in qr_map} 120 | curr['hard_clip'] = (hc_start, hc_end) 121 | serailzed = pickle.dumps(curr) 122 | if self.compress: 123 | serailzed = zlib.compress(serailzed) 124 | self._write_to_lmdb(env, read_id.encode('utf-8'), serailzed) 125 | self.run_stat['Success'] += 1 126 | 127 | # Early stop for testing 128 | # if self.run_stat['Success'] >= 10000: 129 | # break 130 | 131 | # Update run status 132 | t.set_postfix(self.run_stat, refresh=False) 133 | 134 | def args(): 135 | import argparse 136 | parser = argparse.ArgumentParser() 137 | parser.add_argument("--bam", type=str, required=True, help="The aligned bam file") 138 | parser.add_argument("--compress", action="store_true", help="Compress the index entry") 139 | return parser.parse_args() 140 | 141 | if __name__ == "__main__": 142 | import time 143 | start = time.time() 144 | args = args() 145 | input_bam = args.bam 146 | runner = Indexer(compress=args.compress) 147 | runner.build(input_bam) 148 | print(f"Time elapsed: {time.time() - start}") 149 | -------------------------------------------------------------------------------- /xron/utils/fastIO.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Sat Sep 3 14:35:26 2022 5 | 6 | @author: heavens 7 | """ 8 | import os 9 | import numpy as np 10 | import lmdb 11 | import shelve 12 | import pickle 13 | import zlib 14 | from abc import ABC, abstractmethod 15 | 16 | class Database(ABC): 17 | def __init__(self, compress = True): 18 | self.compress = compress 19 | self.db = None 20 | 21 | @abstractmethod 22 | def load(self,index_f): 23 | pass 24 | 25 | @abstractmethod 26 | def get_item(self,key): 27 | pass 28 | 29 | @abstractmethod 30 | def get_keys(self): 31 | pass 32 | 33 | @abstractmethod 34 | def get_values(self): 35 | pass 36 | 37 | @abstractmethod 38 | def get_items(self): 39 | pass 40 | 41 | class ShelveDatabase(Database): 42 | def get_item(self,key): 43 | content = self.db[key] 44 | if self.compress: 45 | return pickle.loads(zlib.decompress(content)) 46 | else: 47 | return pickle.loads(content) 48 | 49 | def load(self,index_f): 50 | self.db = shelve.open(index_f) 51 | 52 | def get_keys(self): 53 | return self.db.keys() 54 | 55 | def get_values(self): 56 | return self.db.values() 57 | 58 | def get_items(self): 59 | return self.db.items() 60 | 61 | def __del__(self): 62 | if self.db is not None: 63 | self.db.close() 64 | 65 | class LmdbDatabase(Database): 66 | def get_item(self,key): 67 | if isinstance(key,str): 68 | key = key.encode('utf-8') 69 | content = self.db.get(key) 70 | if content is None: 71 | raise KeyError(f"{key}") 72 | if self.compress: 73 | return pickle.loads(zlib.decompress(content)) 74 | else: 75 | return pickle.loads(content) 76 | 77 | def load(self,index_f): 78 | self.env = lmdb.open(index_f, readonly=True) 79 | self.db = self.env.begin() 80 | 81 | def get_keys(self): 82 | yield from self.db.cursor().iternext(keys=True,values=False) 83 | 84 | def get_values(self): 85 | yield from self.db.cursor().iternext(keys=False,values=True) 86 | 87 | def get_items(self): 88 | yield from self.db.cursor().iternext(keys=True,values=True) 89 | 90 | def __del__(self): 91 | if self.env is not None: 92 | self.env.close() 93 | 94 | def get_db_instance(backend = "shelve", compress = True): 95 | if backend == "shelve": 96 | return 
ShelveDatabase(compress=compress) 97 | elif backend == "lmdb": 98 | return LmdbDatabase(compress=compress) 99 | else: 100 | raise ValueError("Invalid backend") 101 | 102 | class Indexer(dict): 103 | def __init__(self, 104 | backend = "shelve", 105 | compress = True): 106 | super().__init__() 107 | self.backend = get_db_instance(backend,compress) 108 | 109 | def load(self,index_f): 110 | self.backend.load(index_f) 111 | 112 | def __getitem__(self,key): 113 | return self.backend.get_item(key) 114 | 115 | def __iter__(self): 116 | for key,val in self.backend.get_items(): 117 | #decode key if it is bytes 118 | if isinstance(key,bytes): 119 | key = key.decode('utf-8') 120 | if self.backend.compress: 121 | val = pickle.loads(zlib.decompress(val)) 122 | yield key,val 123 | 124 | def __repr__(self): 125 | return f"Indexer({self.backend})" 126 | 127 | def __str__(self): 128 | return f"Indexer({self.backend})" 129 | 130 | def __del__(self): 131 | self.backend.__del__() 132 | 133 | def read_fastqs(fastq_f): 134 | records = {"sequences":[],"name":[],"quality":[]} 135 | for fastq in os.listdir(fastq_f): 136 | with open(os.path.join(fastq_f,fastq),'r') as f: 137 | for line in f: 138 | if line.startswith("@"): 139 | records['name'].append(line.strip()[1:]) 140 | records['sequences'].append(next(f).strip()) 141 | assert next(f).strip() == "+" #skip the "+" 142 | records['quality'].append(next(f).strip()) 143 | return records 144 | 145 | def read_fastq(fastq): 146 | records = {"sequences":[],"name":[],"quality":[]} 147 | with open(fastq,'r') as f: 148 | for line in f: 149 | if line.startswith("@"): 150 | records['name'].append(line.strip()[1:]) 151 | records['sequences'].append(next(f).strip()) 152 | assert next(f).strip() == "+",print(line) #skip the "+" 153 | records['quality'].append(next(f).strip()) 154 | return records 155 | 156 | def read_fast5(read_h,index = "000"): 157 | result_h = read_h['Analyses/Basecall_1D_%s/BaseCalled_template'%(index)] 158 | logits = result_h['Logits'] 159 | move = result_h['Move'] 160 | try: 161 | seq = str(np.asarray(result_h['Fastq']).astype(str)).split('\n')[1] 162 | except: 163 | seq = np.asarray(result_h['Fastq']).tobytes().decode('utf-8').split('\n')[1] 164 | return np.asarray(logits),np.asarray(move),seq 165 | 166 | def read_entry(read_h,entry:str,index = "000"): 167 | """ 168 | Read a entry given the name 169 | 170 | """ 171 | result_h = read_h['Analyses/Basecall_1D_%s/BaseCalled_template'%(index)] 172 | return np.asarray(result_h[entry]) 173 | 174 | if __name__ == "__main__": 175 | lmdb_index = "/data/HEK293T_RNA004/aligned.sorted.bam.index" 176 | shelve_index = "/data/HEK293T_RNA004/index_test/aligned.sorted.bam.index" 177 | 178 | #test lmdb database 179 | indexer_lmdb = Indexer(backend = "lmdb") 180 | indexer_lmdb.load(lmdb_index) 181 | for key,val in indexer_lmdb: 182 | print(key,val) 183 | break 184 | try: 185 | indexer_lmdb['aaa'] #should rase KeyError 186 | except KeyError as e: 187 | print(e) 188 | indexer_lmdb['00000097-e849-4535-b270-fa6dd6a2ec83'] 189 | 190 | #test shelve database 191 | indexer_shelve = Indexer(backend = "shelve") 192 | indexer_shelve.load(shelve_index) 193 | for key,val in indexer_shelve: 194 | print(key,val) 195 | break 196 | try: 197 | indexer_shelve['aaa'] #should rase KeyError 198 | except KeyError as e: 199 | print(e) 200 | indexer_shelve['00000097-e849-4535-b270-fa6dd6a2ec83'] -------------------------------------------------------------------------------- /xron/utils/vq.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Created on Tue Dec 28 13:33:50 2021 3 | 4 | @author: Haotian Teng 5 | """ 6 | import torch 7 | 8 | class NearstEmbedding(torch.autograd.Function): 9 | """Get the nearest embedding of given input, used in VQ-VAE: 10 | https://arxiv.org/pdf/1711.00937.pdf 11 | """ 12 | @staticmethod 13 | def forward(ctx,input,embd_weight): 14 | """ 15 | Find the nearest neighbour of the embedding variable. 16 | 17 | Parameters 18 | ---------- 19 | ctx : Object. 20 | A object to stash information for backward computation, save 21 | arbitrary objects for use in backward computation by ctx.cave_for_backward method. 22 | input : torch.tensor 23 | The input tensor with size (*,embedding_dim) 24 | embd_weight : torch.tensor 25 | The embedding weight tensor with size (num_embeddings,embedding_dim) 26 | 27 | Returns 28 | ------- 29 | The nearest embedding vector. 30 | 31 | """ 32 | index = vq_idx(input,embd_weight) 33 | ctx.mark_non_differentiable(index) 34 | return torch.index_select(embd_weight,dim = 0,index = index).view_as(input) 35 | 36 | @staticmethod 37 | def backward(ctx,grad_post): 38 | """ 39 | Propgate the gradient from posterior layer back to the previous layer, 40 | dL/dq = dL/de. 41 | """ 42 | grad_input, grad_embd = None,None 43 | if ctx.needs_input_grad[0]: 44 | grad_input = grad_post.clone() 45 | return grad_input,grad_embd 46 | 47 | class NearstEmbeddingIndex(torch.autograd.Function): 48 | """Get the indexs of the nearest embedding of given input, no gradient will 49 | be propagate back. 50 | """ 51 | @staticmethod 52 | def forward(ctx,input,embd_weight): 53 | """ 54 | Find the nearest neighbour of the embedding variable. 55 | 56 | Parameters 57 | ---------- 58 | ctx : Object. 59 | A object to stash information for backward computation, save 60 | arbitrary objects for use in backward computation by ctx.cave_for_backward method. 61 | input : torch.tensor 62 | The input tensor with size (*,embedding_dim) 63 | weight : torch.tensor 64 | The embedding weight tensor with size (num_embeddings,embedding_dim) 65 | 66 | Returns 67 | ------- 68 | The nearest embedding vector. 69 | 70 | """ 71 | with torch.no_grad(): 72 | if input.shape[-1] != embd_weight.shape[-1]: 73 | raise ValueError("Input tensor shape %d is not consistent with the embedding %d."%(input.shape[-1],embd_weight.shape[-1])) 74 | distance = torch.cdist(input.flatten(end_dim = -2),embd_weight) 75 | index = torch.argmin(distance,dim = 1) 76 | ctx.mark_non_differentiable(index) 77 | return index 78 | 79 | @staticmethod 80 | def backward(ctx,grad_post): 81 | """ 82 | No gradient since this is for training embedding. 83 | """ 84 | raise RuntimeError('Trying to call `.grad()` on graph containing ' 85 | '`NearstEmbeddingIndex`. The function `NearstEmbeddingIndex` ' 86 | 'is not differentiable. Use `NearstEmbedding` ' 87 | 'if you want a straight-through estimator of the gradient.') 88 | 89 | def vq_idx(x,embd_weight): 90 | return NearstEmbeddingIndex().apply(x,embd_weight) 91 | 92 | def vq(x:torch.Tensor, embd_weight:torch.Tensor): 93 | """ 94 | Apply vector quantised to the given input x and a coding book embed_weight 95 | 96 | Parameters 97 | ---------- 98 | x : torch.Tensor 99 | The input tensor with shape [*,C], where C is the size of embedding. 100 | embd_weight : torch.Tensor 101 | The coding book of the embedding with shape [N,C] where N is the number 102 | of embeddings and C is the size of embedding. 
103 | 104 | Returns 105 | ------- 106 | e : torch.Tensor 107 | The quantised embedding with same shape as input x. 108 | e_shadow : torch.Tensor 109 | The shadow tensor of e with same shape as input x. 110 | 111 | """ 112 | idx = NearstEmbeddingIndex().apply(x,embd_weight.detach()) 113 | e_shadow = torch.index_select(embd_weight,dim = 0,index = idx).view_as(x) 114 | e = NearstEmbedding().apply(x, embd_weight.detach()) 115 | return e,e_shadow 116 | 117 | def loss(x,embd_weight): 118 | e,e_shadow = vq(x,embd_weight) 119 | mse_loss = torch.nn.MSELoss(reduction = "mean") 120 | sg_q = x.detach() 121 | sg_e = e.detach() 122 | rc_signal = e 123 | rc_loss = mse_loss(rc_signal,x) 124 | embedding_loss = mse_loss(sg_q,e_shadow) 125 | commitment_loss = mse_loss(sg_e,x) 126 | return rc_loss, embedding_loss, commitment_loss 127 | 128 | if __name__ == "__main__": 129 | import torchviz 130 | from matplotlib import pyplot as plt 131 | from matplotlib import image 132 | print("Run testing code for Vector-quantization.") 133 | N_EMBD = 10 134 | EMBD_DIM = 3 135 | torch.manual_seed(1992) 136 | embd = torch.nn.Embedding(N_EMBD,EMBD_DIM) 137 | embd.weight.data.uniform_(-1./N_EMBD, 1./N_EMBD) 138 | embd_weight = embd.weight 139 | x = torch.rand((2,2,EMBD_DIM),requires_grad=True, dtype = torch.float) 140 | 141 | idx = NearstEmbeddingIndex.apply(x,embd_weight) 142 | rc_loss, embedding_loss, commitment_loss= loss(x,embd_weight) 143 | grad_x,grad_embd = torch.autograd.grad((rc_loss, embedding_loss, commitment_loss), (x,embd_weight), create_graph=True) 144 | torchviz.make_dot((grad_x,grad_embd, x, embd_weight, rc_loss, embedding_loss, commitment_loss), 145 | params={"grad_x": grad_x, 146 | "grad_embd":grad_embd, 147 | "x": x, 148 | "embedding_weight":embd_weight, 149 | "rc_loss": rc_loss, 150 | "embedding_loss":embedding_loss, 151 | "commitment_loss":commitment_loss}).render("./all_loss", format="png") 152 | 153 | e = NearstEmbedding().apply(x,embd_weight) 154 | nn_grad ,_ = torch.autograd.grad((e.sum()), (x,embd_weight), create_graph=True,allow_unused=True ) 155 | torchviz.make_dot((nn_grad, x, embd_weight,e), params={"grad_x": nn_grad, "x": x, "embedding_weight":embd_weight,"e":e}).render("./nearestembd", format="png") 156 | img = image.imread("./nearestembd.png") 157 | imgplot = plt.imshow(img) -------------------------------------------------------------------------------- /xron/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Haotian Teng 3 | """ 4 | import torch 5 | from torch import nn 6 | from functools import partial 7 | from typing import List 8 | 9 | 10 | Conv1dk1 = partial(nn.Conv1d,kernel_size = 1) 11 | RevConv1dk1 = partial(nn.ConvTranspose1d, kernel_size = 1) 12 | 13 | class AttentionNormalize(nn.Module): 14 | ##TODO: this module makes the model hard to converge, need to modify. 
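    # Rough outline of the forward pass defined below: a kernel-size-1 convolution (self_map) projects
    # the input to the hidden width; a gating term built from the time-averaged signal (wmf) modulates
    # a per-position projection (wmb) and is added to that pointwise projection; a key projection (wk)
    # then produces a single per-sample scale which is rectified, renormalised and used to rescale the
    # whole output.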
15 | def __init__(self, 16 | in_channels: int, 17 | out_channels: int = 5, 18 | activation: nn.Module = nn.SiLU): 19 | super().__init__() 20 | hidden_num = out_channels 21 | self.hidden_num = hidden_num 22 | self.self_map = nn.Conv1d(in_channels, 23 | hidden_num, 24 | stride = 1, 25 | kernel_size = 1) 26 | self.ReLU = nn.ReLU() 27 | self.wmf = torch.nn.Linear(in_channels,hidden_num) 28 | self.wmb = torch.nn.Linear(in_channels,hidden_num) 29 | self.wk = torch.nn.Linear(hidden_num,hidden_num) 30 | 31 | def forward(self, x: torch.Tensor) -> torch.Tensor: 32 | x0 = self.self_map(x) #[N,C,L] 33 | out = self.wmf(torch.mean(x,dim=2,keepdim=False))#[N,C] 34 | out = out.unsqueeze(dim = 1)*self.wmb(x.permute(0,2,1)) + x0.permute(0,2,1) #[N,L,C] 35 | k = self.wk(out) 36 | scale = torch.mean(torch.sum(k*out,dim = 2),dim = 1) #[N] 37 | scale = self.ReLU(scale) 38 | scale = (0.2+scale)/torch.sqrt(1.2+scale)/torch.sqrt(torch.tensor(self.hidden_num)) 39 | out = out*scale[:,None,None] 40 | # out = self.norm(out)#[N,L,C] 41 | return out.permute(0,2,1) #[N,C,L] 42 | 43 | class Res1d(nn.Module): 44 | def __init__(self, 45 | in_channels: int, 46 | out_channels: int, 47 | kernel_size: int = 3, 48 | stride: int = 1, 49 | activation: nn.Module = nn.SiLU, 50 | batch_norm: nn.Module = nn.BatchNorm1d): 51 | super().__init__() 52 | self.self_map = nn.Conv1d(in_channels, 53 | out_channels, 54 | stride = stride, 55 | kernel_size = stride) 56 | 57 | self.conv1 = nn.Conv1d(in_channels, 58 | out_channels, 59 | stride=1, 60 | kernel_size=1) 61 | self.bn1 = batch_norm(out_channels) 62 | self.conv2 = nn.Conv1d(out_channels, 63 | out_channels, 64 | kernel_size = kernel_size, 65 | stride = stride, 66 | padding = (kernel_size-stride)//2) 67 | self.bn2 = batch_norm(out_channels) 68 | self.activation = activation(inplace = True) 69 | 70 | def forward(self, x: torch.Tensor) -> torch.Tensor: 71 | x0 = self.self_map(x) #[N,C,L] 72 | out = self.conv1(x) #[N,C,L] 73 | dtype = x.dtype 74 | if isinstance(self.bn1,nn.LayerNorm): 75 | # Notice the use of nn.LayerNorm will make bfloat16 training/inference unstable due to this issue: 76 | # https://github.com/pytorch/pytorch/issues/66707 77 | # Autocast will make the output from LayerNorm to float32, which will cause 78 | # problem in downstream rnn operations, need to manually change it back in the 79 | # last CNN layer output (potentiall every layer of CNN cause other wise the 80 | # Resnet layer will calculate in FP32) 81 | # Consider using Apex Fused LayerNorm for bfloat16 training: 82 | # https://nvidia.github.io/apex/layernorm.html 83 | out = out.to(torch.float32) 84 | out = self.bn1(out.permute(0,2,1).contiguous()).permute(0,2,1).contiguous() #[N,C,L] 85 | out = out.to(dtype) 86 | else: 87 | out = self.bn1(out) 88 | out = self.activation(out) 89 | out = self.conv2(out) 90 | if isinstance(self.bn2,nn.LayerNorm): 91 | out = out.to(torch.float32) 92 | out = self.bn2(out.permute(0,2,1).contiguous()).permute(0,2,1).contiguous() 93 | out = out.to(dtype) 94 | else: 95 | out = self.bn2(out) 96 | out = self.activation(out) 97 | return out + x0 98 | 99 | class RevRes1d(nn.Module): 100 | def __init__(self, 101 | in_channels: int, 102 | out_channels: int, 103 | kernel_size: int = 3, 104 | stride: int = 1, 105 | activation: nn.Module = nn.SiLU, 106 | batch_norm: nn.Module = nn.BatchNorm1d): 107 | super().__init__() 108 | self.self_map = nn.ConvTranspose1d(in_channels, 109 | out_channels, 110 | kernel_size = stride, 111 | stride = stride) 112 | self.conv1 = Conv1dk1(in_channels, 
out_channels,stride=1) 113 | self.bn1 = batch_norm(out_channels) 114 | self.conv2 = nn.ConvTranspose1d(out_channels, 115 | out_channels, 116 | kernel_size = kernel_size, 117 | stride = stride, 118 | padding = (kernel_size-stride)//2) 119 | self.bn2 = batch_norm(out_channels) 120 | self.activation = activation(inplace = True) 121 | 122 | def forward(self, x: torch.Tensor) -> torch.Tensor: 123 | x0 = self.self_map(x) 124 | out = self.conv1(x) 125 | out = self.bn1(out) 126 | out = self.activation(out) 127 | out = self.conv2(out) 128 | out = self.bn2(out) 129 | out = self.activation(out) 130 | 131 | out += x0 132 | return out 133 | 134 | class BidirectionalRNN(nn.Module): 135 | def __init__(self, 136 | input_size: int, 137 | hidden_size: int, 138 | num_layers: int, 139 | cell: nn.Module = nn.LSTM): 140 | super(BidirectionalRNN,self).__init__() 141 | self.rnn = cell(input_size,hidden_size,num_layers,bidirectional = True) 142 | self.num_layers = num_layers 143 | self.hidden_size = hidden_size 144 | 145 | def forward(self, x: torch.Tensor) -> torch.Tensor: 146 | output,_ = self.rnn(x) 147 | return output 148 | 149 | class Permute(nn.Module): 150 | def __init__(self,perm:List): 151 | super().__init__() 152 | self.perm = perm 153 | def forward(self,x:torch.Tensor): 154 | return x.permute(self.perm).contiguous() 155 | -------------------------------------------------------------------------------- /xron/nrhmm/method_illustration.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Wed May 11 02:05:35 2022 3 | 4 | @author: Haotian Teng 5 | """ 6 | import os 7 | import time 8 | import toml 9 | import torch 10 | import itertools 11 | import numpy as np 12 | from matplotlib import pyplot as plt 13 | from xron.nrhmm.hmm import GaussianEmissions, RHMM 14 | from xron.nrhmm.hmm_input import Kmer2Transition, Kmer_Dataset 15 | from torchvision import transforms 16 | from torch.utils.data.dataloader import DataLoader 17 | from xron.xron_train_base import DeviceDataLoader 18 | 19 | def kmers2seq(kmers,idx2kmer): 20 | merged = [g for g,_ in itertools.groupby(kmers)] 21 | seqs = [idx2kmer[x][2] for x in merged] 22 | return ''.join(seqs) 23 | 24 | SCRATCH = os.environ['SCRATCH'] 25 | out_f = f"{SCRATCH}/Xron_Project/figures" 26 | torch.manual_seed(1992) 27 | model = torch.load(f"{SCRATCH}/NRHMM_models/xron_rhmm_models_new/ckpt-36234") 28 | emission = GaussianEmissions(model['hmm']['emission.means'].cpu().numpy(), 1*np.ones(3125)[:,None]) 29 | class TestArguments: 30 | input = f"{SCRATCH}/NA12878_RNA_IVT/xron_partial/extracted_kmers/" 31 | # input = "/home/heavens/bridge_scratch/m6A_Nanopore_RNA002/data/m6A_90_pct/20210430_1745_X2_FAQ15454_23428362/kmers_guppy_4000/" 32 | batch_size = 10 33 | device = "cuda" 34 | args = TestArguments 35 | hmm = RHMM(emission,normalize_transition=False,device = args.device) 36 | config = toml.load(os.path.join(args.input,"config.toml")) 37 | print("Readout the pore model.") 38 | chunks = np.load(os.path.join(args.input,"chunks.npy")) 39 | n_samples, sig_len = chunks.shape 40 | durations = np.load(os.path.join(args.input,"durations.npy")) 41 | idx2kmer = config['idx2kmer'] 42 | kmers = np.load(os.path.join(args.input,"kmers.npy")) 43 | base_prior = {x:1 for x in config['alphabeta']} 44 | k2t = Kmer2Transition(alphabeta = config['alphabeta'], 45 | k = config['k'], 46 | T_max = config['chunk_len'], 47 | kmer2idx = config['kmer2idx_dict'], 48 | idx2kmer = config['idx2kmer'], 49 | neighbour_kmer = 2, 50 | base_alternation = 
{"A":"M"}, 51 | base_prior = base_prior, 52 | kmer_replacement = True) 53 | dataset = Kmer_Dataset(chunks, durations, kmers,transform=transforms.Compose([k2t])) 54 | loader = DataLoader(dataset,batch_size = args.batch_size, shuffle = True) 55 | loader = DeviceDataLoader(loader,device = args.device) 56 | for i_batch, batch in enumerate(loader): 57 | signal_batch = batch['signal'] 58 | duration_batch = batch['duration'] 59 | transition_batch = batch['labels'] 60 | kmers_batch = batch['kmers'] 61 | Ls = duration_batch.cpu().numpy() 62 | break 63 | idx = 6 64 | signal = signal_batch[idx].detach().cpu().numpy()[:400] 65 | kmers = kmers_batch[idx].detach().cpu().numpy()[:400] 66 | seq = kmers2seq(kmers,idx2kmer) 67 | 68 | ## Print the alignment 69 | fig,axs = plt.subplots(nrows = 2, 70 | ncols = 2, 71 | gridspec_kw={'width_ratios': [1, 10], 72 | "height_ratios":[1,5], 73 | 'wspace':0, 74 | 'hspace':0}, 75 | figsize = (6,6)) 76 | for i, ax in enumerate(fig.axes): 77 | ax.set_xticklabels([]) 78 | ax.set_yticklabels([]) 79 | axs[0][0].set_axis_off() 80 | axs[0][1].plot(signal) 81 | axs[0][1].xaxis.set_visible(False) 82 | for i,c in enumerate(seq): 83 | axs[1][0].text(x = 0.4,y = i/len(seq)+0.01,s = c, fontsize = 10) 84 | alignment = [] 85 | i=0 86 | for x,g in itertools.groupby(kmers): 87 | alignment += [i]*len(list(g)) 88 | i+=1 89 | alignment = np.asarray(alignment) 90 | axs[1][1].plot(alignment[-1] - alignment,color = "black") 91 | axs[1][1].set_ylim(ymin = -0.5,ymax = alignment[-1]+0.5) 92 | fig.savefig(f"{out_f}/alignment1.png",dpi = 300, bbox_inches = "tight") 93 | # save pdf 94 | fig.savefig(f"{out_f}/alignment1.pdf",dpi = 300, bbox_inches = "tight",format = "pdf") 95 | #save eps 96 | fig.savefig(f"{out_f}/alignment1.eps",dpi = 300, bbox_inches = "tight",format = "eps") 97 | 98 | ## Print the Marcus suggestion 99 | fig,axs = plt.subplots(nrows = 2, 100 | ncols = 2, 101 | gridspec_kw={'width_ratios': [1, 10], 102 | "height_ratios":[1,5], 103 | 'wspace':0, 104 | 'hspace':0}, 105 | figsize = (6,6)) 106 | for i, ax in enumerate(fig.axes): 107 | ax.set_xticklabels([]) 108 | ax.set_yticklabels([]) 109 | axs[0][0].set_axis_off() 110 | axs[0][1].plot(signal) 111 | axs[0][1].xaxis.set_visible(False) 112 | for i,c in enumerate(seq): 113 | axs[1][0].text(x = 0.4,y = i/len(seq)+0.01,s = c, fontsize = 10) 114 | alignment = [] 115 | i=0 116 | for x,g in itertools.groupby(kmers): 117 | alignment += [i]*len(list(g)) 118 | i+=1 119 | alignment = np.asarray(alignment) 120 | axs[1][1].plot(alignment[-1] - alignment,color = "black") 121 | axs[1][1].set_ylim(ymin = -0.5,ymax = alignment[-1]+0.5) 122 | axs[1][1].fill_between(x = np.arange(len(signal)), 123 | y1 = -0.5, 124 | y2 = alignment[-1]+0.5, 125 | color = "grey") 126 | fig.savefig(f"{out_f}/alignment2.png",dpi = 300, bbox_inches = "tight") 127 | # save pdf 128 | fig.savefig(f"{out_f}/alignment2.pdf",dpi = 300, bbox_inches = "tight",format = "pdf") 129 | #save eps 130 | fig.savefig(f"{out_f}/alignment2.eps",dpi = 300, bbox_inches = "tight",format = "eps") 131 | 132 | ## Print the Marcus suggestion 133 | fig,axs = plt.subplots(nrows = 2, 134 | ncols = 2, 135 | gridspec_kw={'width_ratios': [1, 10], 136 | "height_ratios":[1,5], 137 | 'wspace':0, 138 | 'hspace':0}, 139 | figsize = (6,6)) 140 | for i, ax in enumerate(fig.axes): 141 | ax.set_xticklabels([]) 142 | ax.set_yticklabels([]) 143 | axs[0][0].set_axis_off() 144 | axs[0][1].plot(signal) 145 | axs[0][1].xaxis.set_visible(False) 146 | for i,c in enumerate(seq): 147 | axs[1][0].text(x = 0.4,y = 
i/len(seq)+0.01,s = c, fontsize = 10) 148 | alignment = [] 149 | i=0 150 | for x,g in itertools.groupby(kmers): 151 | alignment += [i]*len(list(g)) 152 | i+=1 153 | alignment = np.asarray(alignment) 154 | alignment = alignment[-1] - alignment 155 | axs[1][1].plot(alignment,color = "black") 156 | axs[1][1].fill_between(x = np.arange(len(signal)), 157 | y1 = alignment - 3, 158 | y2 = alignment + 3, 159 | color = "grey") 160 | axs[1][1].set_ylim(ymin = -0.5,ymax = alignment[0]+0.5) 161 | fig.savefig(f"{out_f}/alignment3.png",dpi = 300, bbox_inches = "tight") 162 | # save pdf 163 | fig.savefig(f"{out_f}/alignment3.pdf",dpi = 300, bbox_inches = "tight",format = "pdf") 164 | #save eps 165 | fig.savefig(f"{out_f}/alignment3.eps",dpi = 300, bbox_inches = "tight",format = "eps") 166 | -------------------------------------------------------------------------------- /xron/utils/transfer_methylation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Tue Aug 3 23:24:39 2021 3 | 4 | @author: Haotian Teng 5 | """ 6 | import numpy as np 7 | import os 8 | import sys 9 | import pandas as pd 10 | import argparse 11 | import itertools 12 | from typing import Tuple,Iterable 13 | def load_data(prefix:str,renorm = False,read_meta = False)->Tuple[np.array,Iterable,np.array,np.array,np.array]: 14 | """ 15 | Load data from the prefix folder 16 | 17 | Parameters 18 | ---------- 19 | prefix : str 20 | The folder contains the chunk data and meta information. 21 | renorm : bool 22 | If the renomalization chunks is loaded or the orignal chunks. 23 | 24 | Returns 25 | ------- 26 | chunk : np.array[N,L] 27 | Contains the N chunks which length L. 28 | seq : Iterable[N] 29 | Contains N sequences corresponding. 30 | seq_lens : TYPE 31 | DESCRIPTION. 32 | mms : TYPE 33 | DESCRIPTION. 34 | mms_full : TYPE 35 | DESCRIPTION. 
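    Note (inferred from the code below, not original documentation): seq_lens
    is the sequence-length array loaded from seq_lens.npy, mms is a
    [2, n_reads] array of (mad, median) signal statistics, and mms_full
    repeats those statistics once per chunk of each read; mms and mms_full
    are returned as None unless mm.npy exists and read_meta is True.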
36 | 37 | """ 38 | if renorm: 39 | chunk_f = os.path.join(prefix,"chunks_renorm.npy") 40 | else: 41 | chunk_f = os.path.join(prefix,"chunks.npy") 42 | seq_f = os.path.join(prefix,"seqs.npy") 43 | seq_lens_f = os.path.join(prefix,"seq_lens.npy") 44 | mm_f = os.path.join(prefix,"mm.npy") 45 | meta_f = os.path.join(prefix,"meta.csv") 46 | chunk = np.load(chunk_f,mmap_mode = "r") 47 | seq = np.load(seq_f,mmap_mode = "r") 48 | seq_lens = np.load(seq_lens_f,mmap_mode = "r") 49 | if os.path.exists(mm_f) and read_meta: 50 | mm = np.load(mm_f,allow_pickle=True) 51 | read_ids = mm[-1] 52 | metas = pd.read_csv(meta_f,delimiter = " ",header = None) 53 | metas = metas.dropna(how = 'all') 54 | ids = list(metas[1]) 55 | mask = [True if id in ids else False for id in read_ids] 56 | mad,med,meth,offsets,scales,read_ids = [list(itertools.compress(x,mask)) for x in mm] 57 | mms = np.asarray([mad,med]) 58 | chunk_count = [len(list(j)) for i,j in itertools.groupby(ids)] 59 | mms_full = np.repeat(mms,chunk_count,axis = 1) 60 | else: 61 | mms = None 62 | mms_full = None 63 | return chunk,seq,seq_lens,mms,mms_full 64 | 65 | def main(args): 66 | chunk_m,seq_m,seq_lens_m= [],[],[] 67 | print("Read control dataset.") 68 | if args.control: 69 | chunk_c,seq_c,seq_lens_c,mms_c,mms_full_c= load_data(args.control,args.renorm) 70 | else: 71 | chunk_c,seq_c,seq_lens_c,mms_c,mms_full_c = [],[],[],[],[] 72 | print("Read methylation dataset.") 73 | m_size = 0 74 | if args.meth: 75 | for i,m in enumerate(args.meth): 76 | chunk,seq,seq_lens,mms,mms_full = load_data(m,args.renorm,read_meta = args.shift) 77 | if args.base_replace: 78 | print("A -> M transfer.") 79 | seq = np.array([x.replace('A','M') for x in seq]) 80 | if args.shift: 81 | offset = np.mean(mms_c[1,:])-np.mean(mms[1,:]) 82 | offset /= mms_full[0,:] 83 | chunk += offset[:,None] 84 | chunk_m.append(chunk) 85 | seq_m.append(seq) 86 | seq_lens_m.append(seq_lens) 87 | m_size += chunk.shape[0] 88 | print(" Methylation dataset %d with shape:%d"%(i+1,chunk.shape[0])) 89 | print("Control dataset size:%d"%(chunk_c.shape[0])) 90 | print("Methylation datasets total size:%d"%(m_size)) 91 | if len(args.meth) == 1: 92 | chunk_m = chunk_m[0] 93 | seq_m = seq_m[0] 94 | seq_lens_m = seq_lens_m[0] 95 | else: 96 | chunk_m = np.concatenate(chunk_m,axis = 0) 97 | seq_m = np.concatenate(seq_m,axis = 0) 98 | seq_lens_m = np.concatenate(seq_lens_m,axis = 0) 99 | else: 100 | chunk_m,seq_m,seq_lens_m = [],[],[] 101 | print("Resize the dataset according to cm_ratio.") 102 | if args.cm_ratio: 103 | size_c = len(chunk_c) 104 | size_m = len(chunk_m) 105 | curr_ratio = size_c/float(size_m) 106 | if curr_ratio > args.cm_ratio: 107 | shrink_size = int(size_m*args.cm_ratio) 108 | chunk_c,seq_c,seq_lens_c = chunk_c[:shrink_size],seq_c[:shrink_size],seq_lens_c[:shrink_size] 109 | elif curr_ratio < args.cm_ratio: 110 | shrink_size = int(size_c/args.cm_ratio) 111 | chunk_m,seq_m,seq_lens_m = chunk_m[:shrink_size],seq_m[:shrink_size],seq_lens_m[:shrink_size] 112 | print("Merge control and methylation datasets.") 113 | if args.control and args.meth: 114 | chunk_all = np.concatenate([chunk_m,chunk_c],axis = 0) 115 | seq_all = np.concatenate([seq_m,seq_c],axis = 0) 116 | seq_lens_all = np.concatenate([seq_lens_m,seq_lens_c],axis = 0) 117 | elif args.control: 118 | chunk_all,seq_all,seq_lens_all = chunk_c,seq_c,seq_lens_c 119 | else: 120 | chunk_all,seq_all,seq_lens_all = chunk_m,seq_m,seq_lens_m 121 | out_f = args.output 122 | print("Save the merged dataset.") 123 | if args.cm_ratio: 124 | print("Final 
size: control - %d, methylation - %d, target size ratio %.1f, final size ratio %.1f"%(len(chunk_c),len(chunk_m),args.cm_ratio,len(chunk_c)/float(len(chunk_m)))) 125 | else: 126 | print("Final size: control - %d, methylation - %d"%(len(chunk_c),len(chunk_m))) 127 | os.makedirs(out_f,exist_ok = True) 128 | np.save(os.path.join(out_f,'chunks.npy'),chunk_all) 129 | np.save(os.path.join(out_f,'seqs.npy'),seq_all) 130 | np.save(os.path.join(out_f,'seq_lens.npy'),seq_lens_all) 131 | 132 | if __name__ == "__main__": 133 | parser = argparse.ArgumentParser( 134 | description='Training model with tfrecord file') 135 | parser.add_argument('-m', '--meth', default = None, 136 | help = "The methylation folders, multiple directories separate by comma.") 137 | parser.add_argument('-c', '--control', default = None, 138 | help = "The control folder.") 139 | parser.add_argument('-o', '--output', required = True, 140 | help = "The output folder of the merged dataset.") 141 | parser.add_argument('--shift',action = "store_true", dest = "shift", 142 | help = "If move the methylation signal according to the shift of the control median value.") 143 | parser.add_argument('--cm_ratio',default = None, type = float, 144 | help = "The size ratio of control/methylation, size will be adjusted to a maximum reads.") 145 | parser.add_argument('--no_base_replace', action = "store_false",dest = "base_replace", 146 | help = "If the bases in methylated dataset being transferred from A to M.") 147 | parser.add_argument('--renorm',action="store_true",dest="renorm", 148 | help = "Read the renormalization data instead of original data.") 149 | args = parser.parse_args(sys.argv[1:]) 150 | if not args.control and not args.meth: 151 | raise ValueError("Neither --control or --meth being specified.") 152 | if args.cm_ratio and (not args.control or not args.meth): 153 | raise ValueError("Require both control and methylation dataset being provided when cm_ratio is set.") 154 | args.meth = args.meth.split(',') 155 | print(args.base_replace) 156 | main(args) 157 | -------------------------------------------------------------------------------- /xron/xron_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Mon Mar 15 00:04:58 2021 5 | 6 | @author: haotian teng 7 | """ 8 | import os 9 | import sys 10 | import umap 11 | import torch 12 | import argparse 13 | import numpy as np 14 | import torch.utils.data as data 15 | from itertools import islice 16 | from typing import List,Union 17 | from torchvision import transforms 18 | from matplotlib import pyplot as plt 19 | from xron.xron_input import Dataset, ToTensor 20 | from xron.xron_train_base import Trainer, load_config, DeviceDataLoader 21 | from xron.xron_model import REVCNN,DECODER_CONFIG,CRNN,MM 22 | from torch.distributions.one_hot_categorical import OneHotCategorical as OHC 23 | 24 | class Evaluator(Trainer): 25 | def __init__(self, 26 | encoder:CRNN, 27 | decoders:List[Union[REVCNN,MM]], 28 | config:DECODER_CONFIG, 29 | device:str = None): 30 | device = config.EVALUATION['device'] 31 | super().__init__(train_dataloader=None, 32 | nets = {"encoder":encoder, 33 | "decoder":decoders[0], 34 | "mm":decoders[1]}, 35 | config = config, 36 | device = device, 37 | eval_dataloader = None) 38 | self.encoder = encoder 39 | self.decoder_revcnn, self.decoder_mm = decoders 40 | 41 | def eval_once(self,batch:np.ndarray): 42 | encoder = self.encoder 43 | d1 = self.decoder_revcnn 44 | d2 = 
self.decoder_mm 45 | signal = batch['signal'] 46 | logprob = encoder.forward(signal) #[L,N,C] 47 | prob = torch.exp(logprob) 48 | m = OHC(prob) 49 | sampling = m.sample().permute([1,2,0]) #[L,N,C]->[N,C,L] 50 | rc_signal = d1.forward(sampling).permute([0,2,1]) #[N,L,C] -> [N,C,L] 51 | if d2: 52 | rc_signal += d2.forward(sampling,device = self.device).permute([0,2,1]) 53 | predictions = encoder.ctc_decode(logprob, 54 | alphabet = 'N' + self.config.CTC['alphabeta'], 55 | beam_size = self.config.CTC['beam_size'], 56 | beam_cut_threshold = self.config.CTC['beam_cut_threshold']) 57 | return rc_signal,prob, predictions,sampling 58 | 59 | class VQ_Evaluator(Trainer): 60 | def __init__(self, 61 | encoder:CRNN, 62 | decoders:List[Union[REVCNN,MM]], 63 | config:DECODER_CONFIG, 64 | device:str = None): 65 | device = config.EVALUATION['device'] 66 | super().__init__(train_dataloader=None, 67 | nets = {"encoder":encoder, 68 | "decoder":decoders[0], 69 | "mm":decoders[1]}, 70 | config = config, 71 | device = device, 72 | eval_dataloader = None) 73 | self.encoder = encoder 74 | self.decoder_revcnn, self.decoder_mm = decoders 75 | self.umap_transformer = umap.UMAP() 76 | 77 | def eval_once(self,batch:np.ndarray,umap_visualize = True): 78 | encoder = self.encoder 79 | d1 = self.decoder_revcnn 80 | d2 = self.decoder_mm 81 | # embedding = d2.level_embedding 82 | signal = batch['signal'] 83 | embed = encoder.forward_wo_fnn(signal) #[L,N,C] 84 | if umap_visualize: 85 | u = self.umap_embedding(embed) 86 | rc_signal = d1.forward(embed).permute([0,2,1]) #[N,L,C] -> [N,C,L] 87 | return rc_signal,u 88 | 89 | def umap_embedding(self,embedding:np.ndarray): 90 | u = self.umap_transformer.fit_transform(embedding.view(-1,embedding.shape[-1]).detach().cpu().numpy()) 91 | return u 92 | 93 | 94 | def cmd_args(): 95 | parser = argparse.ArgumentParser( 96 | description='Training model with tfrecord file') 97 | parser.add_argument('-i', '--chunks', required = True, 98 | help = "The .npy file contain chunks.") 99 | parser.add_argument('--model_folder', required = True, 100 | help = "The folder contains the trained model.") 101 | parser.add_argument('--seq', default = None, 102 | help="The .npy file contain the sequence.") 103 | parser.add_argument('--seq_len', default = None, 104 | help="The .npy file contain the sueqnece length.") 105 | parser.add_argument('--device', default = 'cuda', 106 | help="The device used for training, can be cpu or cuda.") 107 | parser.add_argument('--repeat', type = int, default = 5, 108 | help="The repeat used to test.") 109 | parser.add_argument('--method',default = "VQ", 110 | help="The embedding method used to train, can be VQ or MM") 111 | args = parser.parse_args(sys.argv[1:]) 112 | return args 113 | 114 | if __name__ == "__main__": 115 | args = cmd_args() 116 | 117 | #Load model 118 | print("Load model.") 119 | config = load_config(os.path.join(args.model_folder,'config.toml')) 120 | stride = config.CNN['Layers'][-1]['stride'] 121 | config.EVALUATION = {"batch_size":100, 122 | "device":args.device} 123 | encoder = CRNN(config) 124 | revcnn = REVCNN(config) if 'CNN_DECODER' in config.__dict__.keys() else None 125 | mm = MM(config) if 'PORE_MODEL' in config.__dict__.keys() else None 126 | if args.method == "VQ": 127 | e = VQ_Evaluator(encoder,[revcnn,mm],config,device = args.device) 128 | elif args.method == "MM": 129 | e = Evaluator(encoder,[revcnn,mm],config,device = args.device) 130 | e.load(args.model_folder) 131 | 132 | #Load data 133 | print("Load data.") 134 | chunks = 
np.load(args.chunks) 135 | if args.seq: 136 | reference = np.load(args.seq) 137 | else: 138 | reference = None 139 | if args.seq_len: 140 | ref_len = np.load(args.seq_len) 141 | else: 142 | ref_len = None 143 | dataset = Dataset(chunks,seq = reference,seq_len = ref_len,transform = transforms.Compose([ToTensor()])) 144 | loader = data.DataLoader(dataset,batch_size = config.EVALUATION["batch_size"],shuffle = True, num_workers = 4) 145 | DEVICE = args.device 146 | loader = DeviceDataLoader(loader,device = DEVICE) 147 | 148 | #Evaluation 149 | batch = next(islice(loader,2,None)) 150 | rc_signal, umap_vis = e.eval_once(batch) 151 | rc_signal = rc_signal.detach().cpu().numpy() 152 | norm_signal = (rc_signal - np.mean(rc_signal,axis = 2))/np.std(rc_signal,axis = 2) 153 | 154 | #Plot 155 | for r in np.arange(args.repeat): 156 | idx = np.random.randint(low = 0, high = config.EVALUATION['batch_size']-1) 157 | fig,axs = plt.subplots(nrows = 2,figsize = (20,30),gridspec_kw={'height_ratios': [1, 2]}) 158 | start_idx =0 159 | last_idx = 800 160 | axs[0].plot(norm_signal[idx,0,start_idx:last_idx],label = "Reconstruction") 161 | axs[0].plot(batch['signal'].cpu()[idx,0,start_idx:last_idx],label = "Original signal") 162 | axs[0].legend() 163 | axs[1].scatter(umap_vis[:,0],umap_vis[:,1]) 164 | fig.savefig(os.path.join(args.model_folder,'reconstruction_%d.png'%(r))) 165 | -------------------------------------------------------------------------------- /xron/utils/sparse_op.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Fri Jul 1 04:38:02 2022 5 | 6 | @author: haotian teng 7 | """ 8 | import torch 9 | import itertools 10 | import numpy as np 11 | 12 | def log_domain_sparse_matmul(A:torch.sparse_coo_tensor, log_B:torch.Tensor, dim:int = -1): 13 | """ 14 | Do a sparse-dense tensor multiplication and reduced on the given dimension. 15 | 16 | Parameters 17 | ---------- 18 | A : torch.sparse_coo_tensor 19 | A sparse tensor with shape mxnxp. 20 | log_B : torch.Tensor 21 | A dense tensor in the log domain with same shape, broadcast is supported on dense tensor. 22 | dim : int, optional 23 | The dimension to perform reduction on A. The default is -1. 24 | 25 | Returns 26 | ------- 27 | A sparse tensor. 28 | 29 | """ 30 | idxs,update = log_domain_sparse_product(A,log_B) 31 | shape_A = torch.tensor(A.shape) 32 | remain_dims = np.delete(np.arange(idxs.shape[0]),dim) 33 | remain_idxs = idxs[remain_dims,:].T.tolist() 34 | key_func = lambda x: x[1] 35 | update = sorted(zip(update,remain_idxs),key = key_func) 36 | nested = [ (k,list(g)) for k,g in itertools.groupby(update,key = key_func)] 37 | nested_vals = [[y[0] for y in x[1]] for x in nested] 38 | nested_idxs = [x[0] for x in nested] 39 | max_cols = max([len(x) for x in nested_vals]) 40 | padded = torch.tensor([x + [-np.inf]*(max_cols - len(x)) for x in nested_vals],device = A.device) 41 | return torch.sparse_coo_tensor(indices = list(zip(*nested_idxs)), 42 | values = torch.logsumexp(padded,dim = 1), 43 | size = shape_A[remain_dims].tolist(), 44 | device = A.device) 45 | 46 | def log_domain_sparse_matmul_new(A:torch.sparse_coo_tensor, log_B:torch.Tensor, dim:int = -1): 47 | """ 48 | Do a sparse-dense tensor multiplication and reduced on the given dimension. 49 | 50 | Parameters 51 | ---------- 52 | A : torch.sparse_coo_tensor 53 | A sparse tensor with shape mxnxp. 
54 | log_B : torch.Tensor 55 | A dense tensor in the log domain with same shape, broadcast is supported on dense tensor. 56 | dim : int, optional 57 | The dimension to perform reduction on A. The default is -1. 58 | 59 | Returns 60 | ------- 61 | A sparse tensor. 62 | 63 | """ 64 | idxs,update = log_domain_sparse_product(A,log_B) 65 | shape_A = torch.tensor(A.shape) 66 | remain_dims = np.delete(np.arange(idxs.shape[0]),dim) 67 | remain_idxs = idxs[remain_dims,:].T 68 | uniq,uniq_idxs = torch.unique(remain_idxs,dim = 0, return_inverse = True) 69 | nested_vals = [update[uniq_idxs == i].tolist() for i in np.arange(len(uniq))] 70 | # nested_vals = [1]*len(uniq) 71 | max_cols = max([len(x) for x in nested_vals]) 72 | padded = torch.tensor([x + [-np.inf]*(max_cols - len(x)) for x in nested_vals],device = A.device) 73 | return torch.sparse_coo_tensor(indices = uniq.T, 74 | values = torch.logsumexp(padded,dim = 1), 75 | size = shape_A[remain_dims].tolist(), 76 | device = A.device) 77 | 78 | 79 | def log_domain_sparse_max(A:torch.sparse_coo_tensor, log_B:torch.Tensor, dim:int = -1): 80 | """ 81 | Do a sparse-dense tensor multiplication and reduced on the given dimension. 82 | 83 | Parameters 84 | ---------- 85 | A : torch.sparse_coo_tensor 86 | A sparse tensor with shape mxnxp. 87 | log_B : torch.Tensor 88 | A dense tensor in the log domain with same shape, broadcast is supported on dense tensor. 89 | dim : int, optional 90 | The dimension to perform reduction on A. The default is -1. 91 | 92 | Returns 93 | ------- 94 | A sparse tensor. 95 | 96 | """ 97 | idxs,update = log_domain_sparse_product(A,log_B) 98 | n_dims_A = idxs.shape[0] 99 | shape_A = torch.tensor(A.shape) 100 | remain_dims =np.delete(np.arange(n_dims_A),dim) 101 | remain_idxs = list(zip(*[idxs[x].tolist() for x in remain_dims])) 102 | reduced_idxs = idxs[dim].tolist() 103 | update = update.tolist() 104 | key_func = lambda x: x[2] 105 | update = sorted(zip(update,reduced_idxs,remain_idxs),key = key_func) 106 | nested = [ (k,list(g)) for k,g in itertools.groupby(update,key = key_func)] 107 | nested_vals = [[y[0] for y in x[1]] for x in nested] 108 | nested_idxs = list(zip(*[x[0] for x in nested])) 109 | nested_reduced_idxs = [[y[1] for y in x[1]] for x in nested] 110 | max_cols = max([len(x) for x in nested_vals]) 111 | padded = torch.tensor([x + [-np.inf]*(max_cols - len(x)) for x in nested_vals],device = A.device) 112 | padded_reduced_idxs = torch.tensor([x + [-1]*(max_cols - len(x)) for x in nested_reduced_idxs],device = A.device,dtype = torch.long) 113 | result,argmax = torch.max(padded,dim = 1) 114 | argmax = torch.gather(padded_reduced_idxs,dim = 1,index = argmax[:,None]).squeeze(dim = 1) 115 | max_idx = torch.zeros(shape_A[remain_dims].tolist(),device = A.device,dtype = torch.long) 116 | max_idx[tuple(nested_idxs)] = argmax 117 | return torch.sparse_coo_tensor(indices = nested_idxs, 118 | values = result, 119 | size = shape_A[remain_dims].tolist(), 120 | device = A.device).to_dense(),max_idx 121 | 122 | 123 | def log_domain_sparse_product(A:torch.sparse_coo_tensor, log_B:torch.Tensor): 124 | """ 125 | Do a sparse-dense tensor production. 126 | 127 | Parameters 128 | ---------- 129 | A : torch.sparse_coo_tensor 130 | A sparse tensor with shape mxnxp. 131 | log_B : torch.Tensor 132 | A dense tensor in the log domain with same shape, broadcast is supported on dense tensor. 
133 | 134 | Returns 135 | ------- 136 | idxs,vals 137 | 138 | """ 139 | A = A.coalesce() 140 | idxs = A.indices() 141 | log_vals = A.values().log() 142 | shape_A = torch.tensor(A.shape) 143 | shape_B = torch.tensor(log_B.shape) 144 | n_dims_A = idxs.shape[0] 145 | n_dims_B = len(shape_B) 146 | assert n_dims_A == n_dims_B, "Tensor has different number of dimensions." 147 | assert torch.all(torch.logical_or(shape_A == shape_B,shape_B == 1)), "Shape mismatch, got {} and {}. Broadcast only supported on dense tensor.".format(shape_A,shape_B) 148 | idxs_B = idxs.clone() 149 | idxs_B[torch.where(shape_B==1)] = 0 150 | update = log_B[tuple(idxs_B)] + log_vals 151 | return idxs,update 152 | 153 | if __name__ == "__main__": 154 | from time import time 155 | N = 3000 156 | K = 3125 157 | B = 100 158 | i = torch.tensor([np.random.randint(low = 0, high = B, size = N), 159 | np.random.randint(low = 0, high = K, size = N), 160 | np.random.randint(low = 0, high = K, size = N)]) 161 | v = torch.tensor(np.random.rand(N), dtype=torch.float32) 162 | s = torch.sparse_coo_tensor(i, v, [B, K, K]) 163 | d = torch.rand((B,K,1)) 164 | start = time() 165 | result = log_domain_sparse_matmul(s,d,dim = 1) 166 | print(time()-start) 167 | start = time() 168 | result_new = log_domain_sparse_matmul_new(s,d,dim=1) 169 | print(start-time()) 170 | 171 | idxs = s.coalesce().indices() 172 | remain_dims = np.delete(np.arange(idxs.shape[0]),1) 173 | remain_idxs = idxs[remain_dims].T -------------------------------------------------------------------------------- /xron/xron_train_base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Thu Mar 11 16:07:12 2021 3 | 4 | @author: Haotian Teng 5 | """ 6 | import os 7 | import toml 8 | import torch 9 | import itertools 10 | import numpy as np 11 | from typing import Union,Dict 12 | from torch.utils.data.dataloader import DataLoader 13 | from xron.xron_model import CRNN, REVCNN, CONFIG,DECODER_CONFIG,MM 14 | 15 | 16 | class Trainer(object): 17 | def __init__(self, 18 | train_dataloader:DataLoader, 19 | nets:Dict[str,Union[CRNN,REVCNN,MM]], 20 | config:Union[CONFIG,DECODER_CONFIG], 21 | device:str = None, 22 | eval_dataloader:DataLoader = None): 23 | """ 24 | 25 | Parameters 26 | ---------- 27 | train_dataloader : DataLoader 28 | Training dataloader. 29 | nets : Dict[str,Union[CRNN,REVCNN]] 30 | A CRNN or REVCNN network instance. 31 | device: str 32 | The device used to train the model, can be 'cpu' or 'cuda'. 33 | Default is None, use cuda device if it's available. 34 | config: CONFIG 35 | A CONFIG class contains training configurations. Need to contain 36 | at least these parameters: keep_record, device and grad_norm. 37 | eval_dataloader : DataLoader, optional 38 | Evaluation dataloader, if None training dataloader will be used. 39 | The default is None. 40 | 41 | Returns 42 | ------- 43 | None. 
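        Example
        -------
        A minimal illustrative sketch; `train_loader` is an assumed
        DataLoader, not an object defined in this module:

        >>> nets = {"encoder": CRNN(config), "decoder": REVCNN(config)}
        >>> trainer = Trainer(train_loader, nets, config, device="cuda")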
44 | 45 | """ 46 | self.train_ds = train_dataloader 47 | self.device = self._get_device(device) 48 | if eval_dataloader is None: 49 | self.eval_ds = self.train_ds 50 | else: 51 | self.eval_ds = eval_dataloader 52 | self.nets = nets 53 | paras = [x.parameters() for x in self.nets.values()] 54 | self.parameters = itertools.chain(*paras) 55 | for net in self.nets.values(): 56 | net.to(self.device) 57 | self.global_step = 0 58 | self.save_list = [] 59 | self.keep_record = config.TRAIN['keep_record'] 60 | self.grad_norm = config.TRAIN['grad_norm'] 61 | self.config = config 62 | self.losses = [] 63 | self.errors = [] 64 | 65 | def reload_data(self,train_dataloader, eval_dataloader = None): 66 | self.train_ds = train_dataloader 67 | if eval_dataloader is None: 68 | self.eval_ds = train_dataloader 69 | else: 70 | self.eval_ds = eval_dataloader 71 | 72 | def _get_device(self,device): 73 | if device is None: 74 | if torch.cuda.is_available(): 75 | return torch.device('cuda') 76 | else: 77 | return torch.device('cpu') 78 | else: 79 | return torch.device(device) 80 | 81 | def _update_records(self): 82 | record_file = os.path.join(self.save_folder,'records.toml') 83 | with open(record_file,'w+') as f: 84 | toml.dump(self.records,f) 85 | 86 | def save(self): 87 | ckpt_file = os.path.join(self.save_folder,'checkpoint') 88 | current_ckpt = 'ckpt-'+str(self.global_step) 89 | model_file = os.path.join(self.save_folder,current_ckpt) 90 | self.save_list.append(current_ckpt) 91 | if not os.path.isdir(self.save_folder): 92 | os.mkdir(self.save_folder) 93 | if len(self.save_list) > self.keep_record: 94 | os.remove(os.path.join(self.save_folder,self.save_list[0])) 95 | self.save_list = self.save_list[1:] 96 | if os.path.isfile(model_file): 97 | os.remove(model_file) 98 | with open(ckpt_file,'w+') as f: 99 | f.write("latest checkpoint:" + current_ckpt + '\n') 100 | for path in self.save_list: 101 | f.write("checkpoint file:" + path + '\n') 102 | net_dict = {key:net.state_dict() for key,net in self.nets.items()} 103 | torch.save(net_dict,model_file) 104 | 105 | def save_loss(self): 106 | loss_file = os.path.join(self.save_folder,'losses.csv') 107 | error_file = os.path.join(self.save_folder,'errors.csv') 108 | if len(self.losses): 109 | with open(loss_file,'a+') as f: 110 | f.write('\n'.join([str(x) for x in self.losses])) 111 | f.write('\n') 112 | if len(self.errors): 113 | with open(error_file,'a+') as f: 114 | f.write('\n'.join([str(x) for x in self.errors])) 115 | f.write('\n') 116 | self.losses = [] 117 | self.errors = [] 118 | 119 | def _save_config(self): 120 | config_file = os.path.join(self.save_folder,'config.toml') 121 | config_modules = [x for x in self.config.__dir__() if not x .startswith('_')][::-1] 122 | config_dict = {x:getattr(self.config,x) for x in config_modules} 123 | with open(config_file,'w+') as f: 124 | toml.dump(config_dict,f) 125 | 126 | def load(self,save_folder,update_global_step = True): 127 | self.save_folder = save_folder 128 | ckpt_file = os.path.join(save_folder,'checkpoint') 129 | with open(ckpt_file,'r') as f: 130 | latest_ckpt = f.readline().strip().split(':')[1] 131 | if update_global_step: 132 | self.global_step = int(latest_ckpt.split('-')[1]) 133 | ckpt = torch.load(os.path.join(save_folder,latest_ckpt),map_location=self.device) 134 | for key,net in ckpt.items(): 135 | if key in self.nets.keys(): 136 | try: 137 | self.nets[key].load_state_dict(net,strict = True) 138 | except RuntimeError: 139 | print(f"Exact loading {key} failed, try load loosely.") 140 | 
self.nets[key].load_state_dict(net,strict = False) 141 | self.nets[key].to(self.device) 142 | else: 143 | print("%s net is defined in the checkpoint but is not imported because it's not defined in the model."%(key)) 144 | 145 | class DeviceDataLoader(): 146 | """Wrap a dataloader to move data to a device""" 147 | def __init__(self, dataloader, device = None): 148 | self.dataloader = dataloader 149 | if device is None: 150 | device = self.get_default_device() 151 | else: 152 | device = torch.device(device) 153 | self.device = device 154 | 155 | def __iter__(self): 156 | """Yield a batch of data after moving it to device""" 157 | for b in self.dataloader: 158 | yield self._to_device(b, self.device) 159 | 160 | def __len__(self): 161 | """Number of batches""" 162 | return len(self.dataloader) 163 | 164 | def _to_device(self,data,device): 165 | if isinstance(data, (list,tuple)): 166 | return [self._to_device(x,device) for x in data] 167 | if isinstance(data, (dict)): 168 | temp_dict = {} 169 | for key in data.keys(): 170 | temp_dict[key] = self._to_device(data[key],device) 171 | return temp_dict 172 | return data.to(device, non_blocking=True) 173 | 174 | def get_default_device(self): 175 | if torch.cuda.is_available(): 176 | return torch.device('cuda') 177 | else: 178 | return torch.device('cpu') 179 | 180 | def load_config(config_file): 181 | class CONFIG(object): 182 | pass 183 | with open(config_file,'r') as f: 184 | config_dict = toml.load(f) 185 | config = CONFIG() 186 | for k,v in config_dict.items(): 187 | setattr(config,k,v) 188 | return config -------------------------------------------------------------------------------- /xron/watch_training_progress.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Mon Aug 30 15:27:21 2021 3 | 4 | @author: Haotian Teng 5 | """ 6 | import os 7 | from typing import List 8 | import seaborn as sns 9 | import numpy as np 10 | from matplotlib import pyplot as plt 11 | 12 | def watch_errors(fs): 13 | errors = [] 14 | for f in fs: 15 | error_file = os.path.join(f,"errors.csv") 16 | e = [] 17 | with open(error_file,'r') as f: 18 | for line in f: 19 | try: 20 | error = np.float(line.strip()) 21 | except: 22 | print(line.strip()) 23 | if not np.isnan(error): 24 | e.append(error) 25 | errors.append(e) 26 | return errors 27 | 28 | def smooth(scalars: List[float], weight: float) -> List[float]: # Weight between 0 and 1 29 | last = scalars[0] # First value in the plot (first timestep) 30 | smoothed = list() 31 | for point in scalars: 32 | smoothed_val = last * weight + (1 - weight) * point # Calculate smoothed value 33 | smoothed.append(smoothed_val) # Save it 34 | last = smoothed_val # Anchor the last smoothed value 35 | 36 | return smoothed 37 | 38 | def std_plot(errors:list,axs,**plot_args): 39 | ls = [len(x) for x in errors] 40 | max_len = max(ls) 41 | errors = np.asarray([np.pad(x,(0,max_len-y),'constant',constant_values = np.nan) for x,y in zip(errors,ls)]) 42 | e_mean = np.nanmean(errors,axis = 0) 43 | e_std = np.nanstd(errors,axis = 0) 44 | x = np.arange(len(e_mean)) 45 | e_mean_smoothed = smooth(e_mean,plot_args['smooth']) 46 | axs.plot(x,e_mean_smoothed,c = plot_args['color'],linewidth = 1) 47 | axs.fill_between(x,e_mean - e_std,e_mean + e_std,facecolor = plot_args['color'],alpha = plot_args['alpha'],label = plot_args['label']) 48 | return axs 49 | 50 | if __name__ == "__main__": 51 | ### Plot multiple optimizers training result 52 | # prefix = 
"/home/heavens/bridge_scratch/Xron_models_control/xron_model_supervised_control_dataset_%d_%s_16G" 53 | # adam_fs = [prefix%(x,'Adam') for x in np.arange(4)] 54 | # sgd_fs = [prefix%(x,'SGD') for x in np.arange(4)] 55 | # momentum_fs = [prefix%(x,'Momentum') for x in np.arange(4,8,1)] 56 | # adagrad_fs = [prefix%(x,'Adagrad') for x in np.arange(4,8,1)] 57 | # colors = ['r','g','b','yellow'] 58 | # opts = ['Adam','SGD','Momentum','Adagrad'] 59 | # fs = [adam_fs,sgd_fs,momentum_fs,adagrad_fs] 60 | 61 | # Plot single training error 62 | # opts = ['2000-LN-Adam','2000-LN-Adagrad','4000-LN-Adam_1,3','4000-LN-Adam_0,2','4000-LN-Adagrad','2000-BN-Adam','2000-BN-Adagrad'] 63 | # opts = [str(x) for x in range(8)] 64 | # opts = [str(x) for x in range(4)] 65 | colors = sns.color_palette() 66 | 67 | ##Control plotting 68 | 69 | # fs = [['/home/heavens/bridge_scratch/Xron_models_merge_d/xron_model_supervised_merge_Adagrad_transfer_learning_fixD'], 70 | # ['/home/heavens/bridge_scratch/Xron_models_merge_d/xron_model_supervised_merge_Adagrad_transfer_learning_deviationCorrected'], 71 | # ['/home/heavens/bridge_scratch/Xron_models_DirectTrainOnMerge/xron_model_supervised_control_dataset_%d_Adagrad_16G'%(x) for x in np.arange(8)]] 72 | # fs = [['/home/heavens/bridge_scratch/Xron_models_merge_d/xron_model_supervised_merge_Adagrad_transfer_learning_deviationCorrected']] 73 | 74 | # ### Plot all 75 | # title = "Control model accuracy" 76 | # repeat = 4 77 | # fs = [['/home/heavens/bridge_scratch/Xron_models_NAIVT_LayerNorm_2000L/xron_model_supervised_control_dataset_%d_Adam_16G'%(x) for x in np.arange(repeat)], 78 | # ['/home/heavens/bridge_scratch/Xron_models_NAIVT_LayerNorm_2000L/xron_model_supervised_control_dataset_%d_Adagrad_16G'%(x+4) for x in np.arange(repeat)], 79 | # ['/home/heavens/bridge_scratch/Xron_models_NAIVT_LayerNorm_4000L/xron_model_supervised_control_dataset_1_Adam_16G', 80 | # '/home/heavens/bridge_scratch/Xron_models_NAIVT_LayerNorm_4000L/xron_model_supervised_control_dataset_3_Adam_16G'], 81 | # ['/home/heavens/bridge_scratch/Xron_models_NAIVT_LayerNorm_4000L/xron_model_supervised_control_dataset_0_Adam_16G', 82 | # '/home/heavens/bridge_scratch/Xron_models_NAIVT_LayerNorm_4000L/xron_model_supervised_control_dataset_2_Adam_16G'], 83 | # ['/home/heavens/bridge_scratch/Xron_models_NAIVT_LayerNorm_4000L/xron_model_supervised_control_dataset_%d_Adagrad_16G'%(x+4) for x in np.arange(repeat)], 84 | # ['/home/heavens/bridge_scratch/Xron_models_NAIVT/xron_model_supervised_control_dataset_%d_Adam_16G'%(x) for x in np.arange(repeat)], 85 | # ['/home/heavens/bridge_scratch/Xron_models_NAIVT/xron_model_supervised_control_dataset_%d_Adagrad_16G'%(x+4) for x in np.arange(repeat)]] 86 | # opts = ['2000L-Adam'] + ['2000L-Adagrad'] + ['4000L-Adam'] + ['4000L-Adagrad'] + ['Adam'] + ['Adagrad'] 87 | 88 | # ## Plot 8000L 89 | # fs = [['/home/heavens/bridge_scratch/Xron_models_NAIVT_LayerNorm_8000L/xron_model_supervised_control_dataset_%d_16G'%(x)] for x in np.arange(8)] 90 | # opts = np.arange(8) 91 | # title = "8000L training accuracy" 92 | 93 | # ## Plot 4000L 94 | # fs = [['/home/heavens/bridge_scratch/Xron_models_NAIVT_LayerNorm_4000L/xron_model_supervised_control_dataset_%d_Adam_16G'%(x)] for x in np.arange(4)] 95 | # fs += [['/home/heavens/bridge_scratch/Xron_models_NAIVT_LayerNorm_4000L/xron_model_supervised_control_dataset_%d_Adagrad_16G'%(x+4)] for x in np.arange(4)] 96 | # opts = ['Adam']*4 + ['Adagrad']*4 97 | 98 | # title = "4000L control dataset training curve" 99 | 100 | ## Plot 101 | # fs = 
[['/home/heavens/bridge_scratch/Xron_models_NAIVT_LayerNorm_2000L/xron_model_supervised_control_dataset_%d_Adam_16G'%(x)] for x in np.arange(4)] 102 | # fs += [['/home/heavens/bridge_scratch/Xron_models_NAIVT_LayerNorm_2000L/xron_model_supervised_control_dataset_%d_Adagrad_16G'%(x+4)] for x in np.arange(4)] 103 | # opts = ['Adam']*4 + ['Adagrad']*4 104 | # title = "2000L control dataset training curve" 105 | 106 | ## Plot attention model 107 | repeat = 4 108 | fs =[['/home/heavens/bridge_scratch/Xron_models_attention_NAIVT/'], 109 | ['/home/heavens/bridge_scratch/Xron_models_NAIVT_LayerNorm_4000L/xron_model_supervised_control_dataset_1_Adam_16G']] 110 | opts = ['Attention'] + ['Adam'] 111 | title = "Attention model" 112 | 113 | ### Retrain on 100 Methylation 114 | # fs = [['/home/heavens/bridge_scratch/Xron_models_NAIVT+100METH_LayerNorm_4000L/xron_model_%d_CM%d'%(x,x%4+1)] for x in np.arange(8)] 115 | # opts = ['Control-Methylation-Ratio %d:1, Adam'%(x+1) for x in np.arange(4)] + ['Control-Methylation-Ratio %d:1, AdamW'%(x+1) for x in np.arange(4)] 116 | # title = "Retrain only the last two layers" 117 | ### 118 | # fs = [['/home/heavens/bridge_scratch/Xron_models_NAIVT+100METH_LayerNorm_4000L_retrain_all/xron_model_%d_CM%d'%(x,x%4+1)] for x in np.arange(8)] 119 | # opts = ['Control-Methylation-Ratio %d:1, Adam'%(x+1) for x in np.arange(4)] + ['Control-Methylation-Ratio %d:1, AdamW'%(x+1) for x in np.arange(4)] 120 | # title = "Retrain the whole NN." 121 | 122 | ### 123 | # fs = [['/home/heavens/bridge_scratch/Xron_models_NAIVT+100METH+90METH_LayerNorm_4000L/xron_model_%d_CM%d'%(x,x//2%2)] for x in np.arange(8)] 124 | # lrs = ['1e-4','4e-5'] 125 | # opts = ['Control-Methylation-Ratio %d:1, LR:%s'%(x//2%2,lrs[x%2]) for x in np.arange(8)] 126 | # title = "Retrain only the last two layers" 127 | ### 128 | # fs = [['/home/heavens/bridge_scratch/Xron_models_NAIVT+100METH+90METH_LayerNorm_4000L_retrain_all/xron_model_%d_CM%d'%(x,x//2%2)] for x in np.arange(8)] 129 | # lrs = ['1e-4','4e-5'] 130 | # opts = ['Control-Methylation-Ratio %d:1, LR:%s'%(x//2%2,lrs[x%2]) for x in np.arange(8)] 131 | # title = "Retrain the whole NN." 
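    # Note on the smoothing used below (see the `smooth` helper defined above):
    # it is a plain exponential moving average,
    #     smoothed[t] = weight * smoothed[t-1] + (1 - weight) * scalars[t],
    # e.g. smooth([1.0, 0.0, 0.0], weight=0.7) -> [1.0, 0.7, 0.49], so a larger
    # weight keeps more history and flattens step-to-step noise before std_plot
    # draws the mean validation-error curve.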
132 | 133 | 134 | 135 | fig,axs = plt.subplots(figsize = (40,20)) 136 | FONTSIZE = 30 137 | for f,c,opt in zip(fs,colors,opts): 138 | errors = watch_errors(f) 139 | std_plot(errors,axs,color = c,label = opt, alpha = .2, smooth = 0.7) 140 | plt.legend() 141 | axs.set_xlabel("Training step",fontsize = FONTSIZE) 142 | axs.set_ylabel("Validate Error (Editdistance/Sequence Length)",fontsize = FONTSIZE) 143 | plt.legend(fontsize = FONTSIZE) 144 | axs.set_ylim([0,1.0]) 145 | plt.xticks(fontsize = FONTSIZE) 146 | plt.yticks(fontsize = FONTSIZE) 147 | plt.title(title,fontsize = FONTSIZE) 148 | prefix = "/home/heavens/bridge_scratch/Xron_models_merge_d/xron_model_supervised_control_dataset_Adagrad_16G" -------------------------------------------------------------------------------- /xron/xron_label.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Sun Feb 21 05:22:32 2021 3 | 4 | @author: Haotian Teng 5 | """ 6 | # from bwapy import BwaAligner 7 | # from bwapy.libbwa import Alignment 8 | import sys 9 | if 'bwapy' in sys.modules: 10 | raise ImportError("Loading bwapy again will cause error, please restart " \ 11 | "the process.") 12 | import bwapy 13 | import re 14 | import numpy as np 15 | from numpy import ndarray 16 | from typing import List, Optional 17 | from xron.utils.seq_op import raw2seq 18 | from itertools import permutations 19 | 20 | class MetricAligner(bwapy.BwaAligner): 21 | def __init__(self, 22 | reference: str, 23 | min_len: int = 7, 24 | options:str = None 25 | ): 26 | """ 27 | The wrapper aligner class of bwa aligner, measure the quality for 28 | batch of output sequences from a basecall. 29 | 30 | Parameters 31 | ---------- 32 | reference : str 33 | File path of a reference fasta/fastq file. 34 | min_len : int, optional 35 | The minimum sequence length to be aligned. The default is 7. 36 | 37 | """ 38 | if not options: 39 | options = '-x ont2d' 40 | # options = '-A 1 -B 1 -k %d -O 3 -T 0'%(min_len) 41 | super().__init__(index = reference,options = options) 42 | 43 | def permute_align(self, 44 | raw: ndarray, 45 | alphabet:str = "NACGT", 46 | mod_base:dict = {}, 47 | min_len:int = 5, 48 | permute:bool = True 49 | )->ndarray: 50 | """ 51 | Align the input sequence with best permutation. 52 | 53 | Parameters 54 | ---------- 55 | raw : A [N,L] ndarray gives the raw ctc sequence. 56 | where N is the number of sequences and L is the length of each 57 | sequence, where 0 alaways mean the blank symbol. 58 | permute: List[str] Optional 59 | A permute list, default is None, will try every possible permutation. 60 | mod_base: Dict Optional 61 | A dict mapping the modification base to the original base. 62 | Returns 63 | ------- 64 | ndarray 65 | A arra with same length of the sequences, gives the identity scores. 
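        Note: as implemented below, the method returns a tuple (scores, perms):
        `scores` holds one identity array per tried base permutation (each
        identity weighted by sequence length minus min_len, floored at 0), and
        `perms` lists the corresponding permutations of the non-blank alphabet.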
66 | 67 | """ 68 | 69 | ab = alphabet[1:] 70 | seqs = raw2seq(raw) 71 | len_score = np.asarray([len(seq) for seq in seqs])-min_len 72 | len_score[len_score<0] = 0 73 | scores = [] 74 | perms = [] 75 | if permute: 76 | for perm in permutations(ab): 77 | perm_seqs = ["".join([perm[x-1] for x in seq]) for seq in seqs] 78 | identities = self.align(perm_seqs) 79 | identities = identities * len_score 80 | scores.append(identities) 81 | perms.append(perm) 82 | else: 83 | seqs = ["".join([ab[x-1] for x in seq]) for seq in seqs] 84 | identities = self.align(seqs) 85 | identities = identities * len_score 86 | scores.append(identities) 87 | perms.append([x for x in ab]) 88 | return scores, perms 89 | 90 | 91 | def align(self, 92 | sequences: List[str] 93 | ) -> ndarray: 94 | """ 95 | Get the align score given a batch of sequences. 96 | 97 | Parameters 98 | ---------- 99 | sequences : List[str] 100 | A list of the sequences. 101 | Returns 102 | ------- 103 | ndarray 104 | A array with same length of the sequences, gives the identity 105 | scores. 106 | 107 | """ 108 | identities = [] 109 | for seq in sequences: 110 | hits = self.align_seq(seq) 111 | identity = self._match_score(hits[0])/len(seq) if len(hits)>0 else 0 #Only use the first hit. 112 | identities.append(identity) 113 | return np.array(identities) 114 | 115 | def _match_score(self, alignment: Optional[bwapy.libbwa.Alignment]) -> int: 116 | """ 117 | Get the match score given a cigar string. 118 | 119 | Parameters 120 | ---------- 121 | alignment : Optional[Alignment] 122 | A alignment instance from the bwapy mapping result. 123 | 124 | Returns 125 | ------- 126 | int 127 | The nubmer of matched bases. 128 | 129 | """ 130 | cigar_s = alignment.cigar 131 | match_hits = re.findall(r'(\d+)M', cigar_s) 132 | n_match = sum([int(x) for x in match_hits]) 133 | return n_match - alignment.NM 134 | 135 | if __name__ == "__main__": 136 | reference_f = '/home/heavens/twilight_hdd1/m6A_Nanopore/references.fasta' 137 | test_seqs = ['AAGAUUGUGUUGGUCAGUACUCGGUCGACGGGGUGAUUCCUUCAGCCCAGCCUGAUGAAUGCCGACACCUACCGGCUUUUAAGACGCCCAGGUACCAACGAUCAUCGUCGACCGUAAGU', 138 | 
'AUAGAUUGUGUUUGUUAGUCGCUUGGUACGCCGAGCUGAAUGAUUGUGGAUAUCGCGAAGCAGGUUGAGCGCAACCCUGGACGACCGUAUUUCGGCGUCAAUAAGCGUCGUGAAGGCAGGUGAAUAUGAGCAUGGCGUUGGGGCUUUCCGAUAUCAAAGGCGCAGAUGAGCGGUCUGGUAUACCGUGGUUUGUGACUGGCAACGGCUUUAACGUUGGCGAUACGCUGCGUUAUCUGCAGGCAAUUGAGCGGCUGGAGAAAAACUGGCUGUCGAUCCGUCCGAACCGGGCGCGAAAUGCCGAAAGUCGAGCGUCCAGCAGGCGUGGCAUAAUGGAUAAUAAAACUGCCCAGCGCGUCGCGAAGAUGAUGAUGUUAAAGAGAUCCGUUUGGAUGAUUGAAGAACUGCGUGUCAGUAUUCGCCCAACGGUUACGGUGCCUUAUCCUAUCUCCGAUAAGCGUAUUUUACAGGCCAUGGAUCAGAUUACGGCCUAAAGCCAGGCGACGAUGAGCUAAAGAGAUCCAGUACAAGUUCGGAACAAAAAGGAAAUUAUGAAUACAACGCCAUUCCAGCGCUUUAAUGAAACCAUUGCUGGCGGUGAAUACACCCAAUGACAGUGGCUGGCGCAGCGUGAUUUUUCCGGCGAACAACCAUCAUCGGAUUCGCGACCGGCGGGCAAUGCGGUGAUCCCACAUCGACGCGGCCUGUCGCAUCGCCAAGCGCAGGCGUUCCACUGAUUGCGGCGGUAUAAUGGGUCACGCCUUCGACGCCGUUUUUGCGCCGUUGAUUGCCCGGCAUCCCGCGCUACCACACUAUCUGCCAACCGGACGCGCUGAAGCCAUCGAUCUCGCGGAUCGCGAACCAGUCUGGCAUAUUCCGGCUGAGAAAUCUGGCUGGAAGGUCGGUCAACUAUUGCGGCGAAAAUGCCCGUCACCUGCGUCAAUCCGCCAGGCGAAAGAAAACAUUAACGGCUAUCGUUGCGAAGACCCCACCAUGCAGCAGCGCACUAUCGCGUUAUUCCGGCGCGUGACAAACGACGACACCGAUGCGUGCUGGCUGAGUUUUUCCUGGUUUGUUCCAGUACUGCGCCACCUUAACGGCGAUACACGUUUUGCCUGUGUUGAAGAGGGCAUGGGACGGUGGAGCGUUAUCUGUACCGAGAUUGCCGGCGAGCUGCCUGCGUGAUGAAACAGGGUAAGCACCGCGCGGUAAAGAUUUUAUUUCACGUUGGAUAUCCCCACACGCGAUUAUUAGAAACGCAUUGGCGGUAUUAGCAGGCAGAUACGACGCUGCGCAGUGCAAUUAGAACAGAGAGCGUUCGCGAAAAAAUCCCGUCUGCGGGUGAACGGGCAAGUGCGACCACCAUAUCUCAUAAAUUUAAUGAAUAUUUAUAAAAGCAAAAUCGUUGUCGCUCACGGUUUCAUUCAGGACGCGCUAUGGGCGGCAAGUAUUCCGGCCUGCAAAUUGGUAUUUUACUGGUUAAGU', 139 | 'AGACUUGUGUUUGGCGGUCGCUUCUAAACCAGCGUAAAAGA', 140 | 'CAGAUAGAUAGUGUUUGUUAGUCGCUUUGGUGCGAAACUCCACUUAUCUACUUACUACGGUCAAAUUAACGAUCAAGCUUUAAGUACGUCGCCAGAUUGUCUGGUUUAACCCUCAAGUGCGGGACAAUUAUUGGGACGAAUACUACGAGUACGGUCGUGUGGCCAUCUCCACCGUCCCCCGCUGUAAUGCUUUUUUUACUCGUCUUAAGGAUAGGUAUCAUCUGACUAGUACCGGCGUGCAAAAGCUUUUUCCAGGAAAGUAGUCAGGAUGAGGCUGGUGAACUGUUACGAAAGAUGAAGGUAAGACCUCAUCCUACAUUUCACAGGCAAGCCAUUAAAGUGCACUGUCUCGUGGCGAUUGCAACUUCGAUCGGCAUCUGAUCAUUGGCCAGCCCUUAGCCGUGUUUCGGUGUGAAGCAGCGUGGCGGGAAUCGCCAUCAUGAUCCCCGCAUGUAGGGCCCUAUCGGAAUUACACCGUCCUAAUGGUUUUUGAAUUGCGGUGGCGCCAGAAAAGA', 141 | 'AGAUUGUGUUUGUUAGUUGCUUAACAAAAUGGUACAGAGAUGAUAGAAAGGCCUGGAGCCGUCUUCCGGGAAUCCAGAACAGAAUCAUCAACAUUUCCACGAAUCUCAAAGAGCUUAUUGAUUUUGGGCUGAGGCAGCUUGGAAGAAAGCAAUCUCCCAGGCAAAAGCAGCUGUUUUUUUAUCCGUCGAACGGUACUGCACUGCUGUCAUGGGCCGACAGGUUCCAGUAGCUCCUUACCAUGCGAACCGGAUCUACCAGUUUGCCAUCUACUCUGACAUUUUGAAGCUCAGGUGCACCAUUUAUUACACUAACUGUGCAGAUUACCAAUCCAAACUGCUUUCAAUUCGGACACAAAAAAAAGGUUUUACAGGUGCCAGCGAGGAUAAGGCUGGUUUAGUCGAACAAGCCAAUGGAGCAUACACUCUUAUGGACGGCUCCGCCUCCCAAAGGCAAGAAAUGUUGUCUUGCUACUGACAGCGGUUCACCACAACGGCUUGGAGAAUCCUGAACAUAAACGAACAUCAGACGUAUUGUUUAUUUGUGCCACUACAGAAAACCUAGGUCUGCAACUGAAAACAUUUUUGCAGAAUCCCAAUGACCAUUUUAACCAUGUUGUAGAACGAUCGCUUAAAGAAAAGGGUGGAUCUGACGACGUUUUAUUUAGGAAAGAAGCAGAACGAAUUAAGAAAAAUCUGAGUGUGCACAUAGAUGUCUAUAAUGCGCUUAUUCAUUCUGCGAAAUUUGAAAUGUGGGACAGCUGAAAUCAAAUGUCCAGUUAGUUUGUGCACACAGGGUUUUUGUUCCAACCUUGACGGACCGAGGUUAAUGAAAUGACUGUUCGGGAUCUGCCGGAUGAAAUCAAGCAAGAAUGGAUGUCCAGCAGCAAAAUACUGCAAAGGAGCAAAGCCAUUUGAAUGUCAUAAUUUUUUCGAUCAUCUCCGAUUAAGCGGAAGAUGAAACAACCGAAUAUUGACGAGGACCCUUUUCAUUUAACUUGUAUCAUAUUGCGAUCGAAGAAAAGGAG', 142 | 
'AGAUUGUGUUUGUUAGUUGCUUAACAAAAUGGUACAGAGAUGAUAGAAAGGCCUGGAGCCGUCUUCCGGGAAUCCAGAACAGAAUCAUCAACAUUUCCACGAAUCUCAAAGAGCUUAUUGAUUUUGGGCUGAGGCAGCUUGGAAGAAAGCAAUCUCCCAGGCAAAAGCAGCUGUUUUUUUAUCCGUCGAACGGUACUGCACUGCUGUCAUGGGCCGACAGGUUCCAGUAGCUCCUUACCAUGCGAACCGGAUCUACCAGUUUGCCAUCUACUCUGACAUUUUGAAGCUCAGGUGCACCAUUUAUUACACUAACUGUGCAGAUUACCAAUCCAAACUGCUUUCAAUUCGGACACAAAAAAAAGGUUUUACAGGUGCCAGCGAGGAUAAGGCUGGUUUAGUCGAACAAGCCAAUGGAGCAUACACUCUUAUGGACGGCUCCGCCUCCCAAAGGCAAGAAAUGUUGUCUUGCUACUGACAGCGGUUCACCACAACGGCUUGGAGAAUCCUGAACAUAAACGAACAUCAGACGUAUUGUUUAUUUGUGCCACUACAGAAAACCUAGGUCUGCAACUGAAAACAUUUUUGCAGAAUCCCAAUGACCAUUUUAACCAUGUUGUAGAACGAUCGCUUAAAGAAAAGGGUGGAUCUGACGACGUUUUAUUUAGGAAAGAAGCAGAACGAAUUAAGAAAAAUCUGAGUGUGCACAUAGAUGUCUAUAAUGCGCUUAUUCAUUCUGCGAAAUUUGAAAUGUGGGACAGCUGAAAUCAAAUGUCCAGUUAGUUUGUGCACACAGGGUUUUUGUUCCAACCUUGACGGACCGAGGUUAAUGAAAUGACUGUUCGGGAUCUGCCGGAUGAAAUCAAGCAAGAAUGGAUGUCCAGCAGCAAAAUACUGCAAAGGAGCAAAGCCAUUUGAAUGUCAUAAUUUUUUCGAUCAUCUCCGAUUAAGCGGAAGAUGAAACAACCGAAUAUUGACGAGGACCCUUUUCAUUUAACUUGUAUCAUAUUGCGAUCGAAGAAAAGGAG'] 143 | 144 | aln = MetricAligner(reference_f) 145 | scores = aln.align(test_seqs) 146 | print(scores) 147 | -------------------------------------------------------------------------------- /xron/test_VQVAE_speech.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Mon Dec 27 16:35:56 2021 3 | 4 | @author: Haotian Teng 5 | """ 6 | import torch 7 | import torchaudio 8 | """ 9 | This script train a VQ-VAE style embedding network on Nanopore sequencing signal. 10 | @author: Haotian Teng 11 | """ 12 | import os 13 | import sys 14 | import torch 15 | import argparse 16 | import numpy as np 17 | from typing import Union,List 18 | from itertools import chain 19 | from torchvision import transforms 20 | import torch.utils.data as data 21 | from torch.utils.data.dataloader import DataLoader 22 | from xron.xron_input import Dataset, ToTensor, NumIndex,rna_filt,dna_filt 23 | from xron.xron_train_base import Trainer, DeviceDataLoader, load_config 24 | from xron.xron_model import REVCNN,DECODER_CONFIG,CRNN,CRITIC_CONFIG,CRITIC,MM_CONFIG,MM,vq 25 | from xron.xron_label import MetricAligner 26 | from torch.distributions.one_hot_categorical import OneHotCategorical as OHC 27 | 28 | class VQVAETrainer(Trainer): 29 | def __init__(self, 30 | train_dataloader:DataLoader, 31 | encoder:CRNN, 32 | decoder:REVCNN, 33 | mm:MM, 34 | config:Union[DECODER_CONFIG,MM_CONFIG], 35 | device:str = None, 36 | eval_dataloader:DataLoader = None): 37 | """ 38 | 39 | Parameters 40 | ---------- 41 | train_dataloader : DataLoader 42 | Training dataloader. 43 | encoder: CRNN 44 | A Convolutional-Recurrent Neural Network 45 | decoder : REVCNN 46 | REVCNN decoder 47 | mm: MM 48 | Markov Model instance. 49 | device: str 50 | The device used to train the model, can be 'cpu' or 'cuda'. 51 | Default is None, use cuda device if it's available. 52 | config: Union[DECODER_CONFIG,MM_CONFIG] 53 | A CONFIG class contains unsupervised training configurations. Need 54 | to contain at least these parameters: keep_record, device and 55 | grad_norm. 56 | eval_dataloader : DataLoader, optional 57 | Evaluation dataloader, if None training dataloader will be used. 58 | The default is None. 
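        Notes
        -----
        As in a standard VQ-VAE, training combines a reconstruction MSE
        between the decoded and input signal, a codebook (embedding) loss,
        and a commitment loss scaled by config.TRAIN['alpha']; see `train`
        and `train_step` below.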
59 | 60 | """ 61 | super().__init__(train_dataloader=train_dataloader, 62 | nets = {"encoder":encoder, 63 | "decoder":decoder, 64 | "mm":mm}, 65 | config = config, 66 | device = device, 67 | eval_dataloader = eval_dataloader) 68 | self.train_config = config.TRAIN 69 | self.global_step = 0 70 | self.score_average = 0 71 | self.nn_embd = vq 72 | self.mse_loss = torch.nn.MSELoss(reduction = "mean") 73 | self.records = {'rc_losses':[], 74 | 'rc_valid':[], 75 | 'embedding_loss':[], 76 | 'commitment_loss':[]} 77 | @property 78 | def encoder(self): 79 | return self.nets["encoder"] 80 | 81 | @property 82 | def decoder(self): 83 | return self.nets["decoder"] 84 | 85 | @property 86 | def mm(self): 87 | return self.nets["mm"] 88 | 89 | def train(self, 90 | epoches:int, 91 | optimizers:List[torch.optim.Optimizer], 92 | save_cycle:int, 93 | save_folder:str): 94 | """ 95 | Train the encoder-decodr nets. 96 | 97 | Parameters 98 | ---------- 99 | epoches : int 100 | Number of epoches to train. 101 | optimizers : List[torch.optim.Optimizer] 102 | A list of three optimizers, the first one is optimizer training the 103 | encoder parameters, the second one for decoder parameters and the 104 | third one is the optimizer for the embedding. 105 | save_cycle : int 106 | Save every save_cycle batches. 107 | save_folder : str 108 | The folder to save the model and training records. 109 | 110 | Returns 111 | ------- 112 | None. 113 | 114 | """ 115 | self.save_folder = save_folder 116 | self._save_config() 117 | records = self.records 118 | for epoch_i in range(epoches): 119 | for i_batch, batch in enumerate(self.train_ds): 120 | losses = self.train_step(batch) 121 | loss = losses[0] + losses[1] + self.train_config['alpha']*losses[2] 122 | for opt in optimizers: 123 | opt.zero_grad() 124 | loss.backward() 125 | for opt in optimizers: 126 | opt.step() 127 | if (self.global_step+1)%save_cycle==0: 128 | self.save() 129 | eval_i,valid_batch = next(enumerate(self.eval_ds)) 130 | with torch.no_grad(): 131 | valid_rc = self.valid_step(valid_batch) 132 | records["rc_valid"].append(valid_rc.detach().cpu().numpy()[()]) 133 | records['rc_losses'].append(losses[0].detach().cpu().numpy()[()]) 134 | records['embedding_loss'].append(losses[1].detach().cpu().numpy()[()]) 135 | records['commitment_loss'].append(losses[2].detach().cpu().numpy()[()]) 136 | print("Epoch %d Batch %d, rc_loss %f, embedding_loss %f, validation rc %f"%(epoch_i, i_batch, losses[0], losses[1],valid_rc)) 137 | self._update_records() 138 | losses = None 139 | torch.nn.utils.clip_grad_norm_(self.parameters, 140 | max_norm=self.grad_norm) 141 | self.global_step +=1 142 | 143 | def train_step(self,batch): 144 | encoder = self.encoder 145 | decoder = self.decoder 146 | embedding = self.mm.level_embedding 147 | signal = batch['signal'] 148 | q = encoder.forward_wo_fnn(signal) #[N,C,L] 149 | e,e_shadow = self.nn_embd(q.permute([0,2,1]),embedding) #[N,L,C] -> [N,C,L] 150 | e = e.permute([0,2,1]) 151 | e_shadow = e_shadow.permute([0,2,1]) 152 | sg_q = q.detach() 153 | sg_e = e.detach() 154 | # q = q + torch.normal(torch.zeros(q.shape),std = self.train_config['sigma']) 155 | rc_signal = decoder.forward(e).permute([0,2,1]) #[N,L,C] -> [N,C,L] 156 | rc_loss = self.mse_loss(rc_signal,signal) 157 | embedding_loss = self.mse_loss(sg_q,e_shadow) 158 | commitment_loss = self.mse_loss(sg_e,q) 159 | return rc_loss, embedding_loss, commitment_loss 160 | 161 | def valid_step(self,batch): 162 | rc_loss,_,_ = self.train_step(batch) 163 | return rc_loss 164 | 165 | def main(args): 
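    # Summary of the steps below: build a CTC/TRAIN config on top of
    # MM_CONFIG, wrap the provided dataloader so batches are moved to the
    # target device, construct the CRNN encoder, REVCNN decoder and MM
    # embedding, then train them with VQVAETrainer using one Adam optimizer
    # for the encoder+decoder parameters and a separate SGD optimizer for the
    # MM embedding parameters.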
166 | class CTC_CONFIG(MM_CONFIG): 167 | CTC = {"beam_size":5, 168 | "beam_cut_threshold":0.05, 169 | "alphabeta": "ACGTM", 170 | "mode":"rna"} 171 | class TRAIN_CONFIG(CTC_CONFIG): 172 | TRAIN = {"inital_learning_rate":args.lr, 173 | "batch_size":args.batch_size, 174 | "grad_norm":2, 175 | "epsilon":0.1, 176 | "epsilon_decay":0, 177 | "alpha":1.0, #Entropy loss scale factor 178 | "keep_record":5, 179 | "decay":args.decay, 180 | "diff_signal":args.diff} 181 | 182 | config = TRAIN_CONFIG() 183 | config.PORE_MODEL["N_BASE"] = len(config.CTC["alphabeta"]) 184 | print("Construct and load the model.") 185 | model_f = args.model_folder 186 | loader = args.loader 187 | DEVICE = args.device 188 | loader = DeviceDataLoader(loader,device = DEVICE) 189 | if args.retrain: 190 | config_old = load_config(os.path.join(model_f,"config.toml")) 191 | config_old.TRAIN = config.TRAIN #Overwrite training config. 192 | config = config_old 193 | encoder = CRNN(config) 194 | decoder = REVCNN(config) 195 | mm = MM(config) 196 | t = VQVAETrainer(loader,encoder,decoder,mm,config) 197 | if args.retrain: 198 | t.load(model_f) 199 | lr = args.lr 200 | epoches = args.epoches 201 | opt_e = torch.optim.Adam(chain(t.encoder.parameters(),t.decoder.parameters()),lr = lr) 202 | opt_mm = torch.optim.SGD(t.mm.parameters(),lr = lr) 203 | COUNT_CYCLE = args.report 204 | print("Begin training the model.") 205 | t.train(epoches,[opt_e,opt_mm],COUNT_CYCLE,model_f) 206 | 207 | 208 | if __name__ == "__main__": 209 | class Args: 210 | pass 211 | base_f = '/home/heavens/twilight_data1/Yesno' 212 | os.makedirs(base_f,exist_ok = True) 213 | yn_data = torchaudio.datasets.YESNO(base_f,download=True) 214 | data_loader = torch.utils.data.DataLoader(yn_data, 215 | batch_size=10, 216 | shuffle=True, 217 | num_workers=1) 218 | args = Args() 219 | args.loader = data_loader 220 | args.model_folder = base_f + '/model' 221 | args.device = "cuda" 222 | args.lr = 4e-3 223 | args.batch_size = 100 224 | args.epoches = 3 225 | args.report = 10 226 | args.retrain = False 227 | args.decay = 0.99 228 | args.diff = False 229 | if not os.path.isdir(args.model_folder): 230 | os.mkdir(args.model_folder) 231 | main(args) 232 | -------------------------------------------------------------------------------- /xron/nrhmm/hmm_relabel.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Fri May 20 14:53:39 2022 3 | 4 | @author: Haotian Teng 5 | """ 6 | 7 | import os 8 | import sys 9 | import toml 10 | import torch 11 | import argparse 12 | import itertools 13 | import numpy as np 14 | from tqdm import tqdm 15 | from typing import Dict,List 16 | from xron.utils.seq_op import Methylation_DP_Aligner 17 | from xron.nrhmm.hmm import GaussianEmissions, RHMM 18 | from xron.nrhmm.hmm_input import Kmer2Transition, Kmer_Dataset, Normalizer 19 | from torchvision import transforms 20 | from torch.utils.data.dataloader import DataLoader 21 | from xron.xron_train_base import DeviceDataLoader 22 | 23 | def kmers2seq(kmers,idx2kmer): 24 | merged = [g for g, _ in itertools.groupby(kmers)] 25 | seqs = [idx2kmer[x][0] for x in merged] 26 | return ''.join(seqs) + idx2kmer[merged[-1]][1:] 27 | 28 | def print_pore_model(emission_means,idx2kmer,f = None): 29 | for i,mean in enumerate(emission_means): 30 | if f: 31 | f.write("%s: %.2f\n"%(idx2kmer[i],mean)) 32 | else: 33 | print("%s: %.2f"%(idx2kmer[i],mean)) 34 | 35 | def get_effective_kmers(selection:str,idx2kmer:List): 36 | demand_char = [x for x in selection.split("!")[0]] 37 | 
filter_char = [x for x in selection.split("!")[1]] 38 | selected_kmers = [i for i,x in enumerate(idx2kmer) if all([d in x for d in demand_char] + [f not in x for f in filter_char])] 39 | return selected_kmers 40 | 41 | def load(self,save_folder,update_global_step = True): 42 | self.save_folder = save_folder 43 | ckpt_file = os.path.join(save_folder,'checkpoint') 44 | with open(ckpt_file,'r') as f: 45 | latest_ckpt = f.readline().strip().split(':')[1] 46 | if update_global_step: 47 | self.global_step = int(latest_ckpt.split('-')[1]) 48 | ckpt = torch.load(os.path.join(save_folder,latest_ckpt), 49 | map_location=self.device) 50 | for key,net in ckpt.items(): 51 | if key in self.nets.keys(): 52 | self.nets[key].load_state_dict(net,strict = False) 53 | self.nets[key].to(self.device) 54 | else: 55 | print("%s net is defined in the checkpoint but is not imported because it's not defined in the model."%(key)) 56 | 57 | def main(args): 58 | print("Loading data...") 59 | config = toml.load(os.path.join(args.input,"config.toml")) 60 | chunks = np.load(os.path.join(args.input,"chunks.npy"),mmap_mode=args.mmap_mode) 61 | n_samples, sig_len = chunks.shape 62 | durations = np.load(os.path.join(args.input,"durations.npy"),mmap_mode = args.mmap_mode) 63 | idx2kmer = config['idx2kmer'] 64 | kmers = np.load(os.path.join(args.input,"kmers.npy"),mmap_mode = args.mmap_mode) 65 | k2t = Kmer2Transition(alphabeta = config['alphabeta'], 66 | k = config['k'], 67 | T_max = config['chunk_len'], 68 | kmer2idx = config['kmer2idx_dict'], 69 | idx2kmer = config['idx2kmer'], 70 | neighbour_kmer = 4 , 71 | base_prior = {x:args.transition_prior for x in config['alphabeta']}, 72 | base_alternation = {"A":"M"}, 73 | kmer_replacement = args.kmer_replacement, 74 | out_format = args.transition_type) 75 | dataset = Kmer_Dataset(chunks, durations, kmers,transform=transforms.Compose([k2t])) 76 | loader = DataLoader(dataset,batch_size = args.batch_size, shuffle = False) 77 | loader = DeviceDataLoader(loader,device = args.device) 78 | 79 | print("Load the model.") 80 | ckpt_file = os.path.join(args.model,'checkpoint') 81 | with open(ckpt_file,'r') as f: 82 | latest_ckpt = f.readline().strip().split(':')[1] 83 | model = torch.load(os.path.join(args.model,latest_ckpt), 84 | map_location = torch.device(args.device)) 85 | emission = GaussianEmissions(model['hmm']['emission.means'].cpu().numpy(), model['hmm']['emission.cov'].cpu().numpy()) 86 | hmm = RHMM(emission, 87 | normalize_transition=False, 88 | device = args.device, 89 | transition_operation = args.transition_operation, 90 | index_mapping = k2t.idx_map if args.transition_operation == "compact" else None) 91 | 92 | print("Readout the pore model.") 93 | with open(os.path.join(args.model,"pore_model"), "w+") as f: 94 | print_pore_model(model['hmm']['emission.means'],config['idx2kmer'],f) 95 | 96 | aligner = Methylation_DP_Aligner(base_alternation = {'M':'A'}) 97 | if args.effective == "!": 98 | effective_kmers = None 99 | else: 100 | effective_kmers = get_effective_kmers(args.effective, idx2kmer) 101 | norm = Normalizer(use_dwell = True, 102 | effective_kmers=effective_kmers) 103 | sigs,seqs,paths = [],[],[] 104 | for i_batch, batch in tqdm(enumerate(loader)): 105 | renorm_batch = [] 106 | signal_batch = batch['signal'] 107 | duration_batch = batch['duration'] 108 | transition_batch = batch['labels'] 109 | kmers_batch = batch['kmers'] 110 | with torch.no_grad(): 111 | path,logit = hmm.viterbi_decode(signal_batch, duration_batch, transition_batch) 112 | 
paths.append(path.cpu().numpy()) 113 | renorm_batch = signal_batch 114 | for j in tqdm(np.arange(args.renorm+1),desc = "Renorm the signal:"): 115 | with torch.no_grad(): 116 | path,logit = hmm.viterbi_decode(renorm_batch, duration_batch, transition_batch) 117 | rc_signal = np.asarray([[hmm.emission.means[x].item() for x in p] for p in path.cpu().numpy()]) 118 | renorm_batch = norm(renorm_batch.cpu().numpy()[:,:,0],rc_signal, duration_batch.cpu().numpy(),path.cpu().numpy()) 119 | renorm_batch = torch.from_numpy(renorm_batch).unsqueeze(dim = -1).to(args.device) 120 | renorm_batch = renorm_batch.squeeze(dim = -1).cpu().numpy() 121 | sigs.append(renorm_batch) 122 | for i in np.arange(args.batch_size): 123 | if i >= signal_batch.shape[0]: 124 | continue 125 | deco_seq = kmers2seq(path.cpu().numpy()[i][:duration_batch[i]],idx2kmer) 126 | orig_seq = kmers2seq(kmers_batch.cpu().numpy()[i][:duration_batch[i]],idx2kmer) 127 | deco_aln,orig_aln = aligner.align(deco_seq,orig_seq) 128 | final_seq = aligner.merge(deco_aln,orig_aln) 129 | seqs.append(final_seq) 130 | if args.max_n: 131 | if len(seqs) >= args.max_n: 132 | break 133 | seq_lens = [len(i) for i in seqs] 134 | sigs = np.vstack(sigs) 135 | seqs = np.array(seqs) 136 | seq_lens = np.array(seq_lens) 137 | paths = np.concatenate(paths,axis = 0) 138 | np.save(os.path.join(args.input,'chunks_renorm.npy'),sigs[:args.max_n]) 139 | np.save(os.path.join(args.input,'seqs_re.npy'),seqs[:args.max_n]) 140 | np.save(os.path.join(args.input,'seq_re_lens.npy'),seq_lens[:args.max_n]) 141 | np.save(os.path.join(args.input,'path'),paths[:args.max_n]) 142 | 143 | def add_arguments(parser): 144 | parser.add_argument("-i","--input", type = str, required = True, 145 | help = "Data folder contains the chunk, kmer sequence.") 146 | parser.add_argument('-m', '--model', required = True, 147 | help = "The rhmm model folder.") 148 | parser.add_argument("-b","--batch_size", type = int, default = 20, 149 | help = "The batch size to train.") 150 | parser.add_argument("--renorm", type = int, default = 1, 151 | help = "Number of time to renormalize the signal.") 152 | parser.add_argument("--device",type = str, default = None, 153 | help = "The device used to train the model.") 154 | parser.add_argument("--max_n",type = int, default = None, 155 | help = "The maximum number of reads.") 156 | parser.add_argument("--certain_methylation",action = "store_false", 157 | dest = "kmer_replacement", 158 | help = "If we are sure about the methylation state.") 159 | parser.add_argument("--mmap_mode",type = str, default = None, 160 | help = "mmap mode when loding numpy data, default is\ 161 | None which does not enable mmapmode, can be r.") 162 | parser.add_argument("-e","--effective",type = str, default = "!", 163 | help = "A magic string gives the kmers that take into \ 164 | account when doing normalization, for example, \ 165 | A!M means kmers that must have A and must not \ 166 | have M is taken into account.") 167 | parser.add_argument("--transition_prior",type = float, default = 0.3, 168 | help = "The prior probability of transition matrix.") 169 | parser.add_argument("--transition_operation",type = str, default = "sparse") 170 | 171 | def post_args(args): 172 | if args.device is None: 173 | args.device = "cuda" if torch.cuda.is_available() else "cpu" 174 | if args.transition_operation == "compact": 175 | args.transition_type = "compact" 176 | else: 177 | args.transition_type = "sparse" 178 | 179 | if __name__ == "__main__": 180 | parser = argparse.ArgumentParser( 181 | 
description='Training RHMM model') 182 | args = parser.parse_args(sys.argv[1:]) 183 | main(args) 184 | 185 | 186 | 187 | -------------------------------------------------------------------------------- /xron/xron_train_embedding.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script train a VQ-VAE style embedding network on Nanopore sequencing signal. 3 | @author: Haotian Teng 4 | """ 5 | import os 6 | import sys 7 | import torch 8 | import argparse 9 | import numpy as np 10 | from typing import Union,List 11 | from itertools import chain 12 | from torchvision import transforms 13 | import torch.utils.data as data 14 | from torch.utils.data.dataloader import DataLoader 15 | from xron.xron_input import Dataset, ToTensor, NumIndex 16 | from xron.utils.prepare_chunk import rna_filt,dna_filt 17 | from xron.xron_train_base import Trainer, DeviceDataLoader, load_config 18 | from xron.xron_model import REVCNN,DECODER_CONFIG,CRNN,CRITIC_CONFIG,CRITIC,MM_CONFIG,MM 19 | from xron.xron_label import MetricAligner 20 | from torch.distributions.one_hot_categorical import OneHotCategorical as OHC 21 | 22 | class VQVAETrainer(Trainer): 23 | def __init__(self, 24 | train_dataloader:DataLoader, 25 | encoder:CRNN, 26 | decoder:REVCNN, 27 | mm:MM, 28 | config:Union[DECODER_CONFIG,MM_CONFIG], 29 | device:str = None, 30 | eval_dataloader:DataLoader = None): 31 | """ 32 | 33 | Parameters 34 | ---------- 35 | train_dataloader : DataLoader 36 | Training dataloader. 37 | encoder: CRNN 38 | A Convolutional-Recurrent Neural Network 39 | decoder : REVCNN 40 | REVCNN decoder 41 | mm: MM 42 | Markov Model instance. 43 | device: str 44 | The device used to train the model, can be 'cpu' or 'cuda'. 45 | Default is None, use cuda device if it's available. 46 | config: Union[DECODER_CONFIG,MM_CONFIG] 47 | A CONFIG class contains unsupervised training configurations. Need 48 | to contain at least these parameters: keep_record, device and 49 | grad_norm. 50 | eval_dataloader : DataLoader, optional 51 | Evaluation dataloader, if None training dataloader will be used. 52 | The default is None. 53 | 54 | """ 55 | super().__init__(train_dataloader=train_dataloader, 56 | nets = {"encoder":encoder, 57 | "decoder":decoder, 58 | "mm":mm}, 59 | config = config, 60 | device = device, 61 | eval_dataloader = eval_dataloader) 62 | self.train_config = config.TRAIN 63 | self.global_step = 0 64 | self.score_average = 0 65 | self.mse_loss = torch.nn.MSELoss(reduction = "mean") 66 | self.records = {'rc_losses':[], 67 | 'rc_valid':[], 68 | 'embedding_loss':[], 69 | 'commitment_loss':[]} 70 | @property 71 | def encoder(self): 72 | return self.nets["encoder"] 73 | 74 | @property 75 | def decoder(self): 76 | return self.nets["decoder"] 77 | 78 | @property 79 | def mm(self): 80 | return self.nets["mm"] 81 | 82 | def train(self, 83 | epoches:int, 84 | optimizers:List[torch.optim.Optimizer], 85 | save_cycle:int, 86 | save_folder:str): 87 | """ 88 | Train the encoder-decodr nets. 89 | 90 | Parameters 91 | ---------- 92 | epoches : int 93 | Number of epoches to train. 94 | optimizers : List[torch.optim.Optimizer] 95 | A list of three optimizers, the first one is optimizer training the 96 | encoder parameters, the second one for decoder parameters and the 97 | third one is the optimizer for the embedding. 98 | save_cycle : int 99 | Save every save_cycle batches. 100 | save_folder : str 101 | The folder to save the model and training records. 102 | 103 | Returns 104 | ------- 105 | None. 
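        Notes
        -----
        The per-batch objective combines the three terms returned by train_step:
        loss = rc_loss + embedding_loss + alpha * commitment_loss, i.e. the usual
        VQ-VAE reconstruction / codebook / commitment decomposition, with alpha
        read from the TRAIN configuration.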
106 | 107 | """ 108 | self.save_folder = save_folder 109 | self._save_config() 110 | records = self.records 111 | for epoch_i in range(epoches): 112 | for i_batch, batch in enumerate(self.train_ds): 113 | losses = self.train_step(batch) 114 | loss = losses[0] + losses[1] + self.train_config['alpha']*losses[2] 115 | for opt in optimizers: 116 | opt.zero_grad() 117 | loss.backward() 118 | for opt in optimizers: 119 | opt.step() 120 | if (self.global_step+1)%save_cycle==0: 121 | self.save() 122 | eval_i,valid_batch = next(enumerate(self.eval_ds)) 123 | with torch.no_grad(): 124 | valid_rc = self.valid_step(valid_batch) 125 | records["rc_valid"].append(valid_rc.detach().cpu().numpy()[()]) 126 | records['rc_losses'].append(losses[0].detach().cpu().numpy()[()]) 127 | records['embedding_loss'].append(losses[1].detach().cpu().numpy()[()]) 128 | records['commitment_loss'].append(losses[2].detach().cpu().numpy()[()]) 129 | print("Epoch %d Batch %d, rc_loss %f, embedding_loss %f, validation rc %f"%(epoch_i, i_batch, losses[0], losses[1],valid_rc)) 130 | self._update_records() 131 | losses = None 132 | torch.nn.utils.clip_grad_norm_(self.parameters, 133 | max_norm=self.grad_norm) 134 | self.global_step +=1 135 | 136 | def train_step(self,batch): 137 | encoder = self.encoder 138 | decoder = self.decoder 139 | embedding = self.mm.level_embedding 140 | signal = batch['signal'] 141 | e,q,e_shadow,sg_q,sg_e = encoder.forward_embedding(signal,embedding = embedding) 142 | rc_signal = decoder.forward(e).permute([0,2,1]) #[N,L,C] -> [N,C,L] 143 | rc_loss = self.mse_loss(rc_signal,signal) 144 | embedding_loss = self.mse_loss(sg_q,e_shadow) 145 | commitment_loss = self.mse_loss(sg_e,q) 146 | return rc_loss, embedding_loss, commitment_loss 147 | 148 | def valid_step(self,batch): 149 | rc_loss,_,_ = self.train_step(batch) 150 | return rc_loss 151 | 152 | def main(args): 153 | class CTC_CONFIG(MM_CONFIG): 154 | CTC = {"beam_size":5, 155 | "beam_cut_threshold":0.05, 156 | "alphabeta": "ACGTM", 157 | "mode":"rna"} 158 | class TRAIN_CONFIG(CTC_CONFIG): 159 | TRAIN = {"inital_learning_rate":args.lr, 160 | "batch_size":args.batch_size, 161 | "grad_norm":2, 162 | "epsilon":0.1, 163 | "epsilon_decay":0, 164 | "alpha":1.0, #Entropy loss scale factor 165 | "keep_record":5, 166 | "decay":args.decay, 167 | "diff_signal":args.diff} 168 | 169 | config = TRAIN_CONFIG() 170 | config.PORE_MODEL["N_BASE"] = len(config.CTC["alphabeta"]) 171 | print("Read chunks and sequence.") 172 | chunks = np.load(args.chunks,allow_pickle = True,mmap_mode= 'r') 173 | print("Construct and load the model.") 174 | model_f = args.model_folder 175 | dataset = Dataset(chunks,seq = None,seq_len = None,transform = transforms.Compose([ToTensor()])) 176 | loader = data.DataLoader(dataset,batch_size = args.batch_size,shuffle = True, num_workers = 4) 177 | DEVICE = args.device 178 | loader = DeviceDataLoader(loader,device = DEVICE) 179 | if args.retrain: 180 | config_old = load_config(os.path.join(model_f,"config.toml")) 181 | config_old.TRAIN = config.TRAIN #Overwrite training config. 
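        # --load restores the config.toml saved with the checkpoint but keeps the current
        # run's hyperparameters by overwriting only its TRAIN section; an explicit --config
        # file is applied after this block and therefore takes precedence.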
182 | config = config_old 183 | if args.config: 184 | config_old = load_config(args.config) 185 | config_old.TRAIN = config.TRAIN 186 | config = config_old 187 | encoder = CRNN(config) 188 | decoder = REVCNN(config) 189 | mm = MM(config) 190 | t = VQVAETrainer(loader,encoder,decoder,mm,config) 191 | if args.retrain: 192 | t.load(model_f) 193 | lr = args.lr 194 | epoches = args.epoches 195 | opt = torch.optim.Adam(chain(t.encoder.parameters(),t.decoder.parameters(),t.mm.level_embedding.parameters()),lr = lr) 196 | COUNT_CYCLE = args.report 197 | print("Begin training the model.") 198 | t.train(epoches,[opt],COUNT_CYCLE,model_f) 199 | 200 | 201 | if __name__ == "__main__": 202 | parser = argparse.ArgumentParser( 203 | description='Training model with chunks and sequence file') 204 | parser.add_argument('-i', '--chunks', required = True, 205 | help = "The .npy file contain chunks.") 206 | parser.add_argument('-o', '--model_folder', required = True, 207 | help = "The folder to save folder at.") 208 | parser.add_argument('--config', default = None, 209 | help = "The configure file used to train.") 210 | parser.add_argument('--device', default = 'cuda', 211 | help="The device used for training, can be cpu or cuda.") 212 | parser.add_argument('--lr', default = 4e-3, type = float, 213 | help="Initial learning rate.") 214 | parser.add_argument('--batch_size', default = 200, type = int, 215 | help="Training batch size.") 216 | parser.add_argument('--epoches', default = 10, type = int, 217 | help = "The number of epoches to train.") 218 | parser.add_argument('--report', default = 10, type = int, 219 | help = "The interval of training rounds to report.") 220 | parser.add_argument('--load', dest='retrain', action='store_true', 221 | help='Load existed model.') 222 | parser.add_argument('--decay', type = float, default = 0.99, 223 | help="The decay factor of the moving average.") 224 | parser.add_argument('--diffSig',action="store_true",dest = "diff", 225 | help="If the input chunks are diffrential signal.") 226 | args = parser.parse_args(sys.argv[1:]) 227 | args.seq = None 228 | args.seq_len = None 229 | if not os.path.isdir(args.model_folder): 230 | os.mkdir(args.model_folder) 231 | main(args) 232 | -------------------------------------------------------------------------------- /xron/utils/tagging.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Thu Oct 6 22:09:32 2022 5 | 6 | @author: heavens 7 | """ 8 | import os 9 | import re 10 | import sys 11 | import h5py 12 | import numpy as np 13 | import argparse 14 | import pathlib 15 | from tqdm import tqdm 16 | from xron import __version__ 17 | FILE = pathlib.Path(__file__).resolve() 18 | FLAGS = {"5mC":"C+m","5hmC":"C+h","5fC":"C+f","5caC":"C+c", 19 | "5hmU":"T+g","5fU":"T+e","5caU":"T+b", 20 | "6mA":"A+a", 21 | "8oxoG":"G+o", 22 | "Xao":"N+n"} 23 | def check_index(fastq): 24 | if not os.path.exists(fastq+'.index'): 25 | raise FileNotFoundError('fastq index file not found') 26 | 27 | def read_index(fastq): 28 | check_index(fastq) 29 | index_dict = {} 30 | with open(fastq+'.index') as f: 31 | index = f.read().splitlines() 32 | index_dict[index[0]] = index[1] 33 | return index_dict 34 | 35 | def int8_encode(x): 36 | assert np.all(0 <= x) and np.all(x<= 1) 37 | return (x*256).astype(np.uint8) 38 | 39 | class SamParser(object): 40 | def __init__(self,sam_file,fastq_index,modified,format='merge'): 41 | self.sam_file = sam_file 42 | 
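        # The tags written by this parser follow the SAM base-modification convention:
        # MM:Z:<base>+<code> (e.g. A+a for 6mA, see FLAGS above) lists, for each modified
        # call, the number of unmodified canonical bases skipped since the previous one,
        # and ML:B:C carries the matching probabilities scaled to 0-255 (int8_encode
        # multiplies by 256, so a probability of exactly 1.0 falls outside the uint8 range).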
self.fastq_index = fastq_index 43 | self.tagged = False #To identify if the file has been tagged by xron before. 44 | self.modified = modified 45 | self.flag = FLAGS[modified] 46 | self.canonical_base = self.flag.split('+')[0] 47 | self.format = format 48 | @property 49 | def PG_header(self): 50 | return '@PG\tID:xron_tagging\tPN:xron\tVN:%s\tCL:%s' % (__version__, 'python '+str(FILE)+' '+' '.join(sys.argv[1:])) 51 | 52 | def _remove_exisiting_xron_header(self): 53 | for header_line in self.headers: 54 | if 'ID:xron_tagging' in header_line: 55 | print("Found existing xron header, update it") 56 | self.headers.remove(header_line) 57 | self.tagged = True 58 | 59 | def add_header(self): 60 | self._remove_exisiting_xron_header() 61 | for i,h in enumerate(self.headers): 62 | #Edit the PP tag in the first PG header to create chain 63 | if h.startswith('@PG') and "PP:xron_tagging" not in h: 64 | if "PP:" in h: 65 | raise ValueError("PP tag already exists in the first PG header") 66 | split_h = h.split('\t') 67 | h = split_h[:2] + ['PP:xron_tagging'] + split_h[2:] 68 | h = '\t'.join(h) 69 | self.headers[i] = h 70 | break 71 | self.headers = [self.PG_header] + self.headers 72 | 73 | def read_fastq_index(self): 74 | read_ids, read_mappings = [],[] 75 | with open(self.fastq_index,'r') as f: 76 | for line in f: 77 | split_line = line.strip().split('\t') 78 | read_ids.append(split_line[0]) 79 | read_mappings.append(split_line[1]) 80 | self.read_ids = np.asarray(read_ids) 81 | self.read_mappings = np.asarray(read_mappings) 82 | 83 | def read_sam(self): 84 | with open(self.sam_file,'r') as f: 85 | sam = f.read().splitlines() 86 | self.headers,self.alignment = self._parse_header(sam) 87 | self.alignment_ids = [x.split('\t')[0] for x in self.alignment] 88 | 89 | def _parse_header(self,sam): 90 | return [line for line in sam if line.startswith('@')],[line for line in sam if not line.startswith('@')] 91 | 92 | def parse_MMflag(self,aln): 93 | split_aln = aln.strip().split('\t') 94 | MMs = {} 95 | MLs = {} 96 | for field in split_aln: 97 | if field.startswith('MM:Z:'): 98 | MM_tags = field[5:].strip(';').split(';') 99 | for MM_tag in MM_tags: 100 | codes,positions = self.parse_MMtag(MM_tag) 101 | for code,position in zip(codes,positions): 102 | MMs[code] = position 103 | if field.startswith("ML:B:C"): 104 | split_ML = field[7:].split(',') 105 | assert len(split_ML) == sum(len(x.split(',')) for x in MMs.values()) 106 | for key,val in MMs.items(): 107 | c = len(val.split(',')) 108 | MLs[key] = split_ML[:c] 109 | del split_ML[:c] 110 | return MMs,MLs 111 | 112 | def parse_MMtag(self,MM_tag): 113 | #Parse single MM tag for each modification 114 | MM_tag_split = MM_tag.split(',') 115 | title = MM_tag_split[0].strip('.') 116 | unmo_base,mo_base = title.split('+') 117 | mo_codes,mo_pos = [],[] 118 | for m in mo_base: 119 | mo_codes.append(unmo_base+'+'+m) 120 | mo_pos.append(','.join(MM_tag_split[1:])) 121 | return mo_codes,mo_pos 122 | 123 | 124 | def _remove_exisiting_MMrecord(self,MMs,MLs): 125 | if self.flag in MMs.keys(): 126 | del MMs[self.flag] 127 | if self.flag in MLs.keys(): 128 | del MLs[self.flag] 129 | 130 | def _generate_MMtag(self,int8_modified_probability): 131 | #Generate MM tag for each modification 132 | if len(int8_modified_probability) == 0: 133 | return '','' 134 | pos = np.where(int8_modified_probability)[0] 135 | shift = pos[1:] - pos[:-1] 136 | shift -= 1 137 | if len(pos): 138 | gap_count = np.append([pos[0]],shift) 139 | else: 140 | gap_count = [] 141 | mm = ','.join([str(x) for x in 
gap_count]) 142 | ml = ','.join([str(x) for x in int8_modified_probability[pos]]) 143 | return mm,ml 144 | 145 | def _generate_MMtag_nomerge(self,int8_modified_probability): 146 | #Generate MM tag with out merging 0 probability. 147 | if len(int8_modified_probability) == 0: 148 | return '','' 149 | mm = ','.join(['0']*len(int8_modified_probability)) 150 | ml = ','.join([str(x) for x in int8_modified_probability]) 151 | return mm,ml 152 | 153 | def read_modified(self,fast5_read_handle): 154 | modified_probability = None 155 | for entry in fast5_read_handle['Analyses'].keys(): 156 | if not entry.startswith('Basecall_'): 157 | continue #Skip non-basecall entries 158 | if fast5_read_handle['Analyses'][entry].attrs['name'] != 'Xron': 159 | continue 160 | result_h = fast5_read_handle['Analyses'][entry]['BaseCalled_template'] 161 | try: 162 | seq = str(np.asarray(result_h['Fastq']).astype(str)).split('\n')[1] 163 | except: 164 | seq = np.asarray(result_h['Fastq']).tobytes().decode('utf-8').split('\n')[1] 165 | try: 166 | modified_probability = np.asarray(result_h['ModifiedProbability']) 167 | if (len(modified_probability) == 0) and (seq.count(self.canonical_base)): 168 | return None 169 | return int8_encode(modified_probability) 170 | except KeyError: 171 | continue 172 | return None 173 | 174 | def __call__(self): 175 | self.read_sam() 176 | self.read_fastq_index() 177 | self.add_header() 178 | fail_count = 0 179 | flag = self.flag 180 | mmtag_func = self._generate_MMtag if self.format == 'merge' else self._generate_MMtag_nomerge 181 | new_sam = self.headers 182 | uniq_file_list = set(self.read_mappings) 183 | with tqdm() as t: 184 | for fast5f in uniq_file_list: 185 | with h5py.File(fast5f,'r') as root: 186 | for read_id in self.read_ids[self.read_mappings==fast5f]: 187 | try: 188 | idx = self.alignment_ids.index(read_id) 189 | except ValueError: 190 | fail_count +=1 191 | continue 192 | aln = self.alignment[idx] 193 | MMs,MLs = self.parse_MMflag(aln) 194 | self._remove_exisiting_MMrecord(MMs, MLs) 195 | read_id = aln.split('\t')[0] 196 | t.postfix = "Read_id: %s, failed: %d"%(read_id,fail_count) 197 | t.update() 198 | read_h = root['read_' + read_id] 199 | modified_p = self.read_modified(read_h) 200 | if modified_p is None: 201 | fail_count += 1 202 | else: 203 | mm,ml = mmtag_func(modified_p) 204 | MMs[self.flag] = mm 205 | MLs[self.flag] = ml 206 | if len(MMs): 207 | MM_string = 'MM:Z:'+';'.join([k+'.,'+v for k,v in MMs.items()])+';' 208 | ML_string = 'ML:B:C,'+','.join([v for v in MLs.values()]) 209 | if 'MM:Z:' in aln: 210 | aln = re.sub('MM:Z:.*?\t',MM_string+'\t',aln) 211 | else: 212 | aln = aln + '\t' + MM_string +'\t' 213 | if 'ML:B:C' in aln: 214 | aln = re.sub('ML:B:C.*?;',ML_string+';',aln) 215 | else: 216 | aln = aln + ML_string+';' 217 | new_sam.append(aln) 218 | new_sam_file = os.path.splitext(self.sam_file)[0] + ".tagged.sam" 219 | with open(new_sam_file + '','w') as f: 220 | f.write('\n'.join(new_sam)) 221 | 222 | def main(args): 223 | if not args.fastq.endswith('.index'): 224 | args.fastq += '.index' 225 | if args.merge: 226 | format = "merge" 227 | else: 228 | format = "flatten" 229 | sam_writer = SamParser(args.sam,args.fastq,args.modified,format = format) 230 | sam_writer.read_sam() 231 | sam_writer() 232 | 233 | if __name__ == "__main__": 234 | parser = argparse.ArgumentParser(description='Add modification tag into sam file.') 235 | parser.add_argument('--fastq', required = True, type=str, help='The merged fastq file') 236 | parser.add_argument('--sam', required = True, 
type=str, help='The sam file') 237 | parser.add_argument('--modified',default = "6mA", type=str, help='The modified base, \ 238 | can be one of the %s'%list(FLAGS.keys())) 239 | parser.add_argument('--merge',action = "store_true",dest = "merge", 240 | help = "Set the output MM tag format to compact format.") 241 | args = parser.parse_args() 242 | main(args) 243 | --------------------------------------------------------------------------------
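A minimal standalone sketch (not part of the repository) of how the gap encoding implemented in SamParser._generate_MMtag turns a per-base modification-probability array into the MM/ML tag payloads; the function name generate_mmtag and the example array are illustrative only.

import numpy as np

def generate_mmtag(int8_p):
    # Mirrors SamParser._generate_MMtag: positions with a non-zero probability are
    # written as "canonical bases skipped since the previous modified call".
    if len(int8_p) == 0:
        return '', ''
    pos = np.where(int8_p)[0]
    shift = pos[1:] - pos[:-1] - 1
    gap_count = np.append([pos[0]], shift) if len(pos) else []
    mm = ','.join(str(x) for x in gap_count)
    ml = ','.join(str(x) for x in int8_p[pos])
    return mm, ml

# Probabilities for five A's of a read, already scaled to 0-255 by int8_encode:
p = np.array([0, 200, 0, 0, 96], dtype=np.uint8)
print(generate_mmtag(p))  # ('1,2', '200,96') -> "MM:Z:A+a.,1,2;" and "ML:B:C,200,96"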